diff --git a/ai_diplomacy/agent.py b/ai_diplomacy/agent.py
index 3110454..f165d89 100644
--- a/ai_diplomacy/agent.py
+++ b/ai_diplomacy/agent.py
@@ -10,7 +10,7 @@ import ast
from ..config import config
# Assuming BaseModelClient is importable from clients.py in the same directory
-from .clients import BaseModelClient, load_model_client
+from .clients import BaseModelClient
# Import load_prompt and the new logging wrapper from utils
from .utils import load_prompt, run_llm_and_log, log_llm_response, get_prompt_path
@@ -336,7 +336,7 @@ class DiplomacyAgent:
)
# If even the last resort doesn't give a dict, return empty dict
return {}
- except Exception as e:
+ except Exception:
logger.error(f"[{self.power_name}] All JSON extraction attempts failed. Original text: {original_text[:500]}...")
return {}
@@ -965,7 +965,7 @@ class DiplomacyAgent:
other_powers = [p for p in game.powers if p != power_name]
# Create a readable board state string from the board_state dict
- board_state_str = f"Board State:\n"
+ board_state_str = "Board State:\n"
for p_name, power_data in board_state.get("powers", {}).items():
# Get units and centers from the board state
units = power_data.get("units", [])
diff --git a/ai_diplomacy/clients.py b/ai_diplomacy/clients.py
index fc44a73..07e5a43 100644
--- a/ai_diplomacy/clients.py
+++ b/ai_diplomacy/clients.py
@@ -1,41 +1,41 @@
import os
import json
-from json import JSONDecodeError
import re
import logging
import ast # For literal_eval in JSON fallback parsing
import aiohttp # For direct HTTP requests to Responses API
-from typing import List, Dict, Optional, Any, Tuple, NamedTuple
+from typing import List, Dict, Optional, Tuple, NamedTuple
from dotenv import load_dotenv
# Use Async versions of clients
from openai import AsyncOpenAI
-from openai import AsyncOpenAI as AsyncDeepSeekOpenAI # Alias for clarity
+from openai import AsyncOpenAI as AsyncDeepSeekOpenAI # Alias for clarity
from anthropic import AsyncAnthropic
import asyncio
import requests
import google.generativeai as genai
from together import AsyncTogether
-from together.error import APIError as TogetherAPIError # For specific error handling
+from together.error import APIError as TogetherAPIError # For specific error handling
from ..config import config
-from diplomacy.engine.message import GLOBAL
from .game_history import GameHistory
from .utils import load_prompt, run_llm_and_log, log_llm_response, generate_random_seed, get_prompt_path
+
# Import DiplomacyAgent for type hinting if needed, but avoid circular import if possible
from .prompt_constructor import construct_order_generation_prompt, build_context_prompt
# Moved formatter imports to avoid circular import - imported locally where needed
# set logger back to just info
logger = logging.getLogger("client")
-logger.setLevel(logging.DEBUG) # Keep debug for now during async changes
+logger.setLevel(logging.DEBUG) # Keep debug for now during async changes
# Note: BasicConfig might conflict if already configured in lm_game. Keep client-specific for now.
# logging.basicConfig(level=logging.DEBUG) # Might be redundant if lm_game configures root
load_dotenv()
+
##############################################################################
# 1) Base Interface
##############################################################################
@@ -52,7 +52,7 @@ class BaseModelClient:
self.model_name = model_name
self.prompts_dir = prompts_dir
# Load a default initially, can be overwritten by set_system_prompt
- self.system_prompt = load_prompt("system_prompt.txt", prompts_dir=self.prompts_dir)
+ self.system_prompt = load_prompt("system_prompt.txt", prompts_dir=self.prompts_dir)
self.max_tokens = 16000 # default unless overridden
def set_system_prompt(self, content: str):
@@ -76,13 +76,13 @@ class BaseModelClient:
board_state,
power_name: str,
possible_orders: Dict[str, List[str]],
- conversation_text: str, # This is GameHistory
+ conversation_text: str, # This is GameHistory
model_error_stats: dict,
log_file_path: str,
phase: str,
agent_goals: Optional[List[str]] = None,
agent_relationships: Optional[Dict[str, str]] = None,
- agent_private_diary_str: Optional[str] = None, # Added
+ agent_private_diary_str: Optional[str] = None, # Added
) -> List[str]:
"""
1) Builds the prompt with conversation context if available
@@ -90,7 +90,7 @@ class BaseModelClient:
3) Parses JSON block
"""
# The 'conversation_text' parameter was GameHistory. Renaming for clarity.
- game_history_obj = conversation_text
+ game_history_obj = conversation_text
prompt = construct_order_generation_prompt(
system_prompt=self.system_prompt,
@@ -98,7 +98,7 @@ class BaseModelClient:
board_state=board_state,
power_name=power_name,
possible_orders=possible_orders,
- game_history=game_history_obj, # Pass GameHistory object
+ game_history=game_history_obj, # Pass GameHistory object
agent_goals=agent_goals,
agent_relationships=agent_relationships,
agent_private_diary_str=agent_private_diary_str,
@@ -108,7 +108,7 @@ class BaseModelClient:
raw_response = ""
# Initialize success status. Will be updated based on outcome.
success_status = "Failure: Initialized"
- parsed_orders_for_return = self.fallback_orders(possible_orders) # Default to fallback
+ parsed_orders_for_return = self.fallback_orders(possible_orders) # Default to fallback
try:
# Call LLM using the logging wrapper
@@ -118,36 +118,29 @@ class BaseModelClient:
log_file_path=log_file_path,
power_name=power_name,
phase=phase,
- response_type='order', # Context for run_llm_and_log's own error logging
- temperature=0
- )
- logger.debug(
- f"[{self.model_name}] Raw LLM response for {power_name} orders:\n{raw_response}"
+ response_type="order", # Context for run_llm_and_log's own error logging
+ temperature=0,
)
+ logger.debug(f"[{self.model_name}] Raw LLM response for {power_name} orders:\n{raw_response}")
# Conditionally format the response based on USE_UNFORMATTED_PROMPTS
if config.USE_UNFORMATTED_PROMPTS:
# Local import to avoid circular dependency
from .formatter import format_with_gemini_flash, FORMAT_ORDERS
+
# Format the natural language response into structured format
formatted_response = await format_with_gemini_flash(
- raw_response,
- FORMAT_ORDERS,
- power_name=power_name,
- phase=phase,
- log_file_path=log_file_path
+ raw_response, FORMAT_ORDERS, power_name=power_name, phase=phase, log_file_path=log_file_path
)
else:
# Use the raw response directly (already formatted)
formatted_response = raw_response
-
+
# Attempt to parse the final "orders" from the formatted response
move_list = self._extract_moves(formatted_response, power_name)
if not move_list:
- logger.warning(
- f"[{self.model_name}] Could not extract moves for {power_name}. Using fallback."
- )
+ logger.warning(f"[{self.model_name}] Could not extract moves for {power_name}. Using fallback.")
if model_error_stats is not None and self.model_name in model_error_stats:
model_error_stats[self.model_name].setdefault("order_decoding_errors", 0)
model_error_stats[self.model_name]["order_decoding_errors"] += 1
@@ -163,18 +156,18 @@ class BaseModelClient:
max_invalid_to_log = 5
display_invalid_moves = invalid_moves_list[:max_invalid_to_log]
omitted_count = len(invalid_moves_list) - len(display_invalid_moves)
-
+
invalid_moves_str = ", ".join(display_invalid_moves)
if omitted_count > 0:
invalid_moves_str += f", ... ({omitted_count} more)"
-
+
success_status = f"Failure: Invalid LLM Moves ({len(invalid_moves_list)}): {invalid_moves_str}"
# If some moves were validated despite others being invalid, it's still not a full 'Success'
# because the LLM didn't provide a fully usable set of orders without intervention/fallbacks.
# The fallback_orders logic within _validate_orders might fill in missing pieces,
# but the key is that the LLM *proposed* invalid moves.
- if not validated_moves: # All LLM moves were invalid
- logger.warning(f"[{power_name}] All LLM-proposed moves were invalid. Using fallbacks. Invalid: {invalid_moves_list}")
+ if not validated_moves: # All LLM moves were invalid
+ logger.warning(f"[{power_name}] All LLM-proposed moves were invalid. Using fallbacks. Invalid: {invalid_moves_list}")
else:
logger.info(f"[{power_name}] Some LLM-proposed moves were invalid. Using fallbacks/validated. Invalid: {invalid_moves_list}")
else:
@@ -186,16 +179,16 @@ class BaseModelClient:
# Fallback is already set to parsed_orders_for_return
finally:
# Log the attempt regardless of outcome
- if log_file_path: # Only log if a path is provided
+ if log_file_path: # Only log if a path is provided
log_llm_response(
log_file_path=log_file_path,
model_name=self.model_name,
power_name=power_name,
phase=phase,
- response_type="order_generation", # Specific type for CSV logging
- raw_input_prompt=prompt, # Renamed from 'prompt' to match log_llm_response arg
+ response_type="order_generation", # Specific type for CSV logging
+ raw_input_prompt=prompt, # Renamed from 'prompt' to match log_llm_response arg
raw_response=raw_response,
- success=success_status
+ success=success_status,
# token_usage and cost can be added later if available and if log_llm_response supports them
)
return parsed_orders_for_return
@@ -216,57 +209,43 @@ class BaseModelClient:
if not matches:
# Some LLMs might not put the colon or might have triple backtick fences.
- logger.debug(
- f"[{self.model_name}] Regex parse #1 failed for {power_name}. Trying alternative patterns."
- )
+ logger.debug(f"[{self.model_name}] Regex parse #1 failed for {power_name}. Trying alternative patterns.")
# 1b) Check for inline JSON after "PARSABLE OUTPUT"
pattern_alt = r"PARSABLE OUTPUT\s*\{(.*?)\}\s*$"
matches = re.search(pattern_alt, raw_response, re.DOTALL)
-
+
if not matches:
# 1c) Check for **PARSABLE OUTPUT:** pattern (with asterisks)
- logger.debug(
- f"[{self.model_name}] Regex parse #2 failed for {power_name}. Trying asterisk-wrapped pattern."
- )
+ logger.debug(f"[{self.model_name}] Regex parse #2 failed for {power_name}. Trying asterisk-wrapped pattern.")
pattern_asterisk = r"\*\*PARSABLE OUTPUT:\*\*\s*(\{[\s\S]*?\})"
matches = re.search(pattern_asterisk, raw_response, re.DOTALL)
if not matches:
- logger.debug(
- f"[{self.model_name}] Regex parse #3 failed for {power_name}. Trying triple-backtick code fences."
- )
+ logger.debug(f"[{self.model_name}] Regex parse #3 failed for {power_name}. Trying triple-backtick code fences.")
# 2) If still no match, check for triple-backtick code fences containing JSON
if not matches:
code_fence_pattern = r"```json\n(.*?)\n```"
matches = re.search(code_fence_pattern, raw_response, re.DOTALL)
if matches:
- logger.debug(
- f"[{self.model_name}] Found triple-backtick JSON block for {power_name}."
- )
-
+ logger.debug(f"[{self.model_name}] Found triple-backtick JSON block for {power_name}.")
+
# 2b) Also try plain ``` code fences without json marker
if not matches:
code_fence_plain = r"```\n(.*?)\n```"
matches = re.search(code_fence_plain, raw_response, re.DOTALL)
if matches:
- logger.debug(
- f"[{self.model_name}] Found plain triple-backtick block for {power_name}."
- )
-
+ logger.debug(f"[{self.model_name}] Found plain triple-backtick block for {power_name}.")
+
# 2c) Try to find bare JSON object anywhere in the response
if not matches:
- logger.debug(
- f"[{self.model_name}] No explicit markers found for {power_name}. Looking for bare JSON."
- )
+ logger.debug(f"[{self.model_name}] No explicit markers found for {power_name}. Looking for bare JSON.")
# Look for a JSON object that contains "orders" key
bare_json_pattern = r'(\{[^{}]*"orders"\s*:\s*\[[^\]]*\][^{}]*\})'
matches = re.search(bare_json_pattern, raw_response, re.DOTALL)
if matches:
- logger.debug(
- f"[{self.model_name}] Found bare JSON object with 'orders' key for {power_name}."
- )
+ logger.debug(f"[{self.model_name}] Found bare JSON object with 'orders' key for {power_name}.")
# 3) Attempt to parse JSON if we found anything
json_text = None
@@ -283,9 +262,7 @@ class BaseModelClient:
json_text = json_text.strip()
if not json_text:
- logger.debug(
- f"[{self.model_name}] No JSON text found in LLM response for {power_name}."
- )
+ logger.debug(f"[{self.model_name}] No JSON text found in LLM response for {power_name}.")
return None
# 3a) Try JSON loading
@@ -293,14 +270,12 @@ class BaseModelClient:
data = json.loads(json_text)
return data.get("orders", None)
except json.JSONDecodeError as e:
- logger.warning(
- f"[{self.model_name}] JSON decode failed for {power_name}: {e}. Trying to fix common issues."
- )
-
+ logger.warning(f"[{self.model_name}] JSON decode failed for {power_name}: {e}. Trying to fix common issues.")
+
# Try to fix common JSON issues
try:
# Remove trailing commas
- fixed_json = re.sub(r',\s*([\}\]])', r'\1', json_text)
+ fixed_json = re.sub(r",\s*([\}\]])", r"\1", json_text)
# Fix single quotes to double quotes
fixed_json = fixed_json.replace("'", '"')
# Try parsing again
@@ -308,14 +283,12 @@ class BaseModelClient:
logger.info(f"[{self.model_name}] Successfully parsed JSON after fixes for {power_name}")
return data.get("orders", None)
except json.JSONDecodeError:
- logger.warning(
- f"[{self.model_name}] JSON decode still failed after fixes for {power_name}. Trying to remove inline comments."
- )
-
+ logger.warning(f"[{self.model_name}] JSON decode still failed after fixes for {power_name}. Trying to remove inline comments.")
+
# Try to remove inline comments (// style)
try:
# Remove // comments from each line
- lines = json_text.split('\n')
+ lines = json_text.split("\n")
cleaned_lines = []
for line in lines:
# Find // that's not inside quotes
@@ -326,33 +299,31 @@ class BaseModelClient:
if escape_next:
escape_next = False
continue
- if char == '\\':
+ if char == "\\":
escape_next = True
continue
if char == '"' and not escape_next:
in_quotes = not in_quotes
- if not in_quotes and line[i:i+2] == '//':
+ if not in_quotes and line[i : i + 2] == "//":
comment_pos = i
break
-
+
if comment_pos >= 0:
# Remove comment but keep any trailing comma
cleaned_line = line[:comment_pos].rstrip()
else:
cleaned_line = line
cleaned_lines.append(cleaned_line)
-
- comment_free_json = '\n'.join(cleaned_lines)
+
+ comment_free_json = "\n".join(cleaned_lines)
# Also remove trailing commas after comment removal
- comment_free_json = re.sub(r',\s*([\}\]])', r'\1', comment_free_json)
-
+ comment_free_json = re.sub(r",\s*([\}\]])", r"\1", comment_free_json)
+
data = json.loads(comment_free_json)
logger.info(f"[{self.model_name}] Successfully parsed JSON after removing inline comments for {power_name}")
return data.get("orders", None)
except json.JSONDecodeError:
- logger.warning(
- f"[{self.model_name}] JSON decode still failed after removing comments for {power_name}. Trying bracket fallback."
- )
+ logger.warning(f"[{self.model_name}] JSON decode still failed after removing comments for {power_name}. Trying bracket fallback.")
# 3b) Attempt bracket fallback: we look for the substring after "orders"
# E.g. "orders: ['A BUD H']" and parse it. This is risky but can help with minor JSON format errors.
@@ -366,30 +337,26 @@ class BaseModelClient:
if isinstance(moves, list):
return moves
except Exception as e2:
- logger.warning(
- f"[{self.model_name}] Bracket fallback parse also failed for {power_name}: {e2}"
- )
+ logger.warning(f"[{self.model_name}] Bracket fallback parse also failed for {power_name}: {e2}")
# If all attempts failed
return None
-
- def _validate_orders(
- self, moves: List[str], possible_orders: Dict[str, List[str]]
- ) -> Tuple[List[str], List[str]]: # MODIFIED RETURN TYPE
+
+ def _validate_orders(self, moves: List[str], possible_orders: Dict[str, List[str]]) -> Tuple[List[str], List[str]]: # MODIFIED RETURN TYPE
"""
Filter out invalid moves, fill missing with HOLD, else fallback.
Returns a tuple: (validated_moves, invalid_moves_found)
"""
logger.debug(f"[{self.model_name}] Proposed LLM moves: {moves}")
validated = []
- invalid_moves_found = [] # ADDED: To collect invalid moves
+ invalid_moves_found = [] # ADDED: To collect invalid moves
used_locs = set()
if not isinstance(moves, list):
logger.debug(f"[{self.model_name}] Moves not a list, fallback.")
# Return fallback and empty list for invalid_moves_found as no specific LLM moves were processed
- return self.fallback_orders(possible_orders), []
-
+ return self.fallback_orders(possible_orders), []
+
for move_str in moves:
# Check if it's in possible orders
if any(move_str in loc_orders for loc_orders in possible_orders.values()):
@@ -399,21 +366,21 @@ class BaseModelClient:
used_locs.add(parts[1][:3])
else:
logger.debug(f"[{self.model_name}] Invalid move from LLM: {move_str}")
- invalid_moves_found.append(move_str) # ADDED: Collect invalid move
+ invalid_moves_found.append(move_str) # ADDED: Collect invalid move
# Fill missing with hold
for loc, orders_list in possible_orders.items():
if loc not in used_locs and orders_list:
hold_candidates = [o for o in orders_list if o.endswith("H")]
- validated.append(
- hold_candidates[0] if hold_candidates else orders_list[0]
- )
+ validated.append(hold_candidates[0] if hold_candidates else orders_list[0])
- if not validated and not invalid_moves_found: # Only if LLM provided no valid moves and no invalid moves (e.g. empty list from LLM)
+ if not validated and not invalid_moves_found: # Only if LLM provided no valid moves and no invalid moves (e.g. empty list from LLM)
logger.warning(f"[{self.model_name}] No valid LLM moves provided and no invalid ones to report. Using fallback.")
return self.fallback_orders(possible_orders), []
- elif not validated and invalid_moves_found: # All LLM moves were invalid
- logger.warning(f"[{self.model_name}] All LLM moves invalid ({len(invalid_moves_found)} found), using fallback. Invalid: {invalid_moves_found}")
+ elif not validated and invalid_moves_found: # All LLM moves were invalid
+ logger.warning(
+ f"[{self.model_name}] All LLM moves invalid ({len(invalid_moves_found)} found), using fallback. Invalid: {invalid_moves_found}"
+ )
# We return empty list for validated, but the invalid_moves_found list is populated
return self.fallback_orders(possible_orders), invalid_moves_found
@@ -442,9 +409,8 @@ class BaseModelClient:
# log_file_path: str, # Not used directly by build_context_prompt
agent_goals: Optional[List[str]] = None,
agent_relationships: Optional[Dict[str, str]] = None,
- agent_private_diary_str: Optional[str] = None, # Added
+ agent_private_diary_str: Optional[str] = None, # Added
) -> str:
-
instructions = load_prompt("planning_instructions.txt", prompts_dir=self.prompts_dir)
context = self.build_context_prompt(
@@ -455,7 +421,7 @@ class BaseModelClient:
game_history,
agent_goals=agent_goals,
agent_relationships=agent_relationships,
- agent_private_diary=agent_private_diary_str, # Pass diary string
+ agent_private_diary=agent_private_diary_str, # Pass diary string
prompts_dir=self.prompts_dir,
)
@@ -472,7 +438,7 @@ class BaseModelClient:
# log_file_path: str, # Not used directly by build_context_prompt
agent_goals: Optional[List[str]] = None,
agent_relationships: Optional[Dict[str, str]] = None,
- agent_private_diary_str: Optional[str] = None, # Added
+ agent_private_diary_str: Optional[str] = None, # Added
) -> str:
# MINIMAL CHANGE: Just change to load unformatted version conditionally
instructions = load_prompt(get_prompt_path("conversation_instructions.txt"), prompts_dir=self.prompts_dir)
@@ -486,19 +452,19 @@ class BaseModelClient:
game_history,
agent_goals=agent_goals,
agent_relationships=agent_relationships,
- agent_private_diary=agent_private_diary_str, # Pass diary string
+ agent_private_diary=agent_private_diary_str, # Pass diary string
prompts_dir=self.prompts_dir,
)
-
+
# KEEP ORIGINAL: Get recent messages targeting this power to prioritize responses
recent_messages_to_power = game_history.get_recent_messages_to_power(power_name, limit=3)
-
+
# KEEP ORIGINAL: Debug logging to verify messages
logger.info(f"[{power_name}] Found {len(recent_messages_to_power)} high priority messages to respond to")
if recent_messages_to_power:
for i, msg in enumerate(recent_messages_to_power):
- logger.info(f"[{power_name}] Priority message {i+1}: From {msg['sender']} in {msg['phase']}: {msg['content'][:50]}...")
-
+ logger.info(f"[{power_name}] Priority message {i + 1}: From {msg['sender']} in {msg['phase']}: {msg['content'][:50]}...")
+
# KEEP ORIGINAL: Add a section for unanswered messages
unanswered_messages = "\n\nRECENT MESSAGES REQUIRING YOUR ATTENTION:\n"
if recent_messages_to_power:
@@ -506,25 +472,32 @@ class BaseModelClient:
unanswered_messages += f"\nFrom {msg['sender']} in {msg['phase']}: {msg['content']}\n"
else:
unanswered_messages += "\nNo urgent messages requiring direct responses.\n"
-
+
final_prompt = context + unanswered_messages + "\n\n" + instructions
- final_prompt = final_prompt.replace('AUSTRIA', 'Austria').replace('ENGLAND', "England").replace('FRANCE', 'France').replace('GERMANY', 'Germany').replace('ITALY', "Italy").replace('RUSSIA', 'Russia').replace('TURKEY', 'Turkey')
+ final_prompt = (
+ final_prompt.replace("AUSTRIA", "Austria")
+ .replace("ENGLAND", "England")
+ .replace("FRANCE", "France")
+ .replace("GERMANY", "Germany")
+ .replace("ITALY", "Italy")
+ .replace("RUSSIA", "Russia")
+ .replace("TURKEY", "Turkey")
+ )
return final_prompt
- async def get_planning_reply( # Renamed from get_plan to avoid conflict with get_plan in agent.py
+ async def get_planning_reply( # Renamed from get_plan to avoid conflict with get_plan in agent.py
self,
game,
board_state,
power_name: str,
possible_orders: Dict[str, List[str]],
game_history: GameHistory,
- game_phase: str, # Used for logging
- log_file_path: str, # Used for logging
+ game_phase: str, # Used for logging
+ log_file_path: str, # Used for logging
agent_goals: Optional[List[str]] = None,
agent_relationships: Optional[Dict[str, str]] = None,
- agent_private_diary_str: Optional[str] = None, # Added
+ agent_private_diary_str: Optional[str] = None, # Added
) -> str:
-
prompt = self.build_planning_prompt(
game,
board_state,
@@ -535,7 +508,7 @@ class BaseModelClient:
# log_file_path, # Not passed to build_planning_prompt directly
agent_goals=agent_goals,
agent_relationships=agent_relationships,
- agent_private_diary_str=agent_private_diary_str, # Pass diary string
+ agent_private_diary_str=agent_private_diary_str, # Pass diary string
)
# Call LLM using the logging wrapper
@@ -544,12 +517,12 @@ class BaseModelClient:
prompt=prompt,
log_file_path=log_file_path,
power_name=power_name,
- phase=game_phase, # Use game_phase for logging
- response_type='plan_reply', # Changed from 'plan' to avoid confusion
+ phase=game_phase, # Use game_phase for logging
+ response_type="plan_reply", # Changed from 'plan' to avoid confusion
)
logger.debug(f"[{self.model_name}] Raw LLM response for {power_name} planning reply:\n{raw_response}")
return raw_response
-
+
async def get_conversation_reply(
self,
game,
@@ -559,18 +532,18 @@ class BaseModelClient:
game_history: GameHistory,
game_phase: str,
log_file_path: str,
- active_powers: Optional[List[str]] = None,
+ active_powers: Optional[List[str]] = None,
agent_goals: Optional[List[str]] = None,
agent_relationships: Optional[Dict[str, str]] = None,
- agent_private_diary_str: Optional[str] = None,
+ agent_private_diary_str: Optional[str] = None,
) -> List[Dict[str, str]]:
"""
Generates a negotiation message, considering agent state.
"""
- raw_input_prompt = "" # Initialize for finally block
- raw_response = "" # Initialize for finally block
- success_status = "Failure: Initialized" # Default status
- messages_to_return = [] # Initialize to ensure it's defined
+ raw_input_prompt = "" # Initialize for finally block
+ raw_response = "" # Initialize for finally block
+ success_status = "Failure: Initialized" # Default status
+ messages_to_return = [] # Initialize to ensure it's defined
try:
raw_input_prompt = self.build_conversation_prompt(
@@ -581,7 +554,7 @@ class BaseModelClient:
game_history,
agent_goals=agent_goals,
agent_relationships=agent_relationships,
- agent_private_diary_str=agent_private_diary_str,
+ agent_private_diary_str=agent_private_diary_str,
)
logger.debug(f"[{self.model_name}] Conversation prompt for {power_name}:\n{raw_input_prompt}")
@@ -591,33 +564,30 @@ class BaseModelClient:
prompt=raw_input_prompt,
log_file_path=log_file_path,
power_name=power_name,
- phase=game_phase,
- response_type='negotiation', # For run_llm_and_log's internal context
+ phase=game_phase,
+ response_type="negotiation", # For run_llm_and_log's internal context
)
logger.debug(f"[{self.model_name}] Raw LLM response for {power_name}:\n{raw_response}")
-
+
# Conditionally format the response based on USE_UNFORMATTED_PROMPTS
if config.USE_UNFORMATTED_PROMPTS:
# Local import to avoid circular dependency
from .formatter import format_with_gemini_flash, FORMAT_CONVERSATION
+
# Format the natural language response into structured JSON
formatted_response = await format_with_gemini_flash(
- raw_response,
- FORMAT_CONVERSATION,
- power_name=power_name,
- phase=game_phase,
- log_file_path=log_file_path
+ raw_response, FORMAT_CONVERSATION, power_name=power_name, phase=game_phase, log_file_path=log_file_path
)
else:
# Use the raw response directly (already formatted)
formatted_response = raw_response
-
+
parsed_messages = []
json_blocks = []
json_decode_error_occurred = False
-
+
# For formatted response, we expect a clean JSON array
- try:
+ try:
data = json.loads(formatted_response)
if isinstance(data, list):
parsed_messages = data
@@ -628,14 +598,14 @@ class BaseModelClient:
logger.warning(f"[{self.model_name}] Failed to parse formatted response as JSON, falling back to regex")
# Fall back to original parsing logic using formatted_response
raw_response = formatted_response
-
+
# Original parsing logic as fallback
if not parsed_messages:
# Attempt to find blocks enclosed in {{...}}
- double_brace_blocks = re.findall(r'\{\{(.*?)\}\}', raw_response, re.DOTALL)
+ double_brace_blocks = re.findall(r"\{\{(.*?)\}\}", raw_response, re.DOTALL)
if double_brace_blocks:
# If {{...}} blocks are found, assume each is a self-contained JSON object
- json_blocks.extend(['{' + block.strip() + '}' for block in double_brace_blocks])
+ json_blocks.extend(["{" + block.strip() + "}" for block in double_brace_blocks])
else:
# If no {{...}} blocks, look for ```json ... ``` markdown blocks
code_block_match = re.search(r"```json\n(.*?)\n```", raw_response, re.DOTALL)
@@ -650,10 +620,10 @@ class BaseModelClient:
json_blocks = [json.dumps(data)]
except json.JSONDecodeError:
# If parsing the whole block fails, fall back to regex for individual objects
- json_blocks = re.findall(r'\{.*?\}', potential_json_array_or_objects, re.DOTALL)
+ json_blocks = re.findall(r"\{.*?\}", potential_json_array_or_objects, re.DOTALL)
else:
# If no markdown block, fall back to regex for any JSON object in the response
- json_blocks = re.findall(r'\{.*?\}', raw_response, re.DOTALL)
+ json_blocks = re.findall(r"\{.*?\}", raw_response, re.DOTALL)
# Process json_blocks if we have them from fallback parsing
if not parsed_messages and json_blocks:
@@ -661,13 +631,13 @@ class BaseModelClient:
try:
cleaned_block = block.strip()
# Attempt to fix common JSON issues like trailing commas before parsing
- cleaned_block = re.sub(r',\s*([\}\]])', r'\1', cleaned_block)
+ cleaned_block = re.sub(r",\s*([\}\]])", r"\1", cleaned_block)
parsed_message = json.loads(cleaned_block)
parsed_messages.append(parsed_message)
except json.JSONDecodeError as e:
logger.warning(f"[{self.model_name}] Failed to parse JSON block {block_index} for {power_name}: {e}")
json_decode_error_occurred = True
-
+
if not parsed_messages:
logger.warning(f"[{self.model_name}] No valid messages found in response for {power_name}")
success_status = "Success: No messages found"
@@ -684,7 +654,7 @@ class BaseModelClient:
else:
logger.warning(f"[{self.model_name}] Invalid message structure for {power_name}")
parsed_messages = validated_messages
-
+
# Set final status and return value
if parsed_messages:
success_status = "Success: Messages extracted"
@@ -695,11 +665,11 @@ class BaseModelClient:
logger.debug(f"[{self.model_name}] Validated conversation replies for {power_name}: {messages_to_return}")
# return messages_to_return # Return will happen in finally block or after
-
+
except Exception as e:
logger.error(f"[{self.model_name}] Error in get_conversation_reply for {power_name}: {e}", exc_info=True)
success_status = f"Failure: Exception ({type(e).__name__})"
- messages_to_return = [] # Ensure empty list on general exception
+ messages_to_return = [] # Ensure empty list on general exception
finally:
if log_file_path:
log_llm_response(
@@ -710,11 +680,11 @@ class BaseModelClient:
response_type="negotiation_message",
raw_input_prompt=raw_input_prompt,
raw_response=raw_response,
- success=success_status
+ success=success_status,
)
return messages_to_return
- async def get_plan( # This is the original get_plan, now distinct from get_planning_reply
+ async def get_plan( # This is the original get_plan, now distinct from get_planning_reply
self,
game,
board_state,
@@ -724,14 +694,14 @@ class BaseModelClient:
log_file_path: str,
agent_goals: Optional[List[str]] = None,
agent_relationships: Optional[Dict[str, str]] = None,
- agent_private_diary_str: Optional[str] = None, # Added
+ agent_private_diary_str: Optional[str] = None, # Added
) -> str:
"""
Generates a strategic plan for the given power based on the current state.
This method is called by the agent's generate_plan method.
"""
logger.info(f"Client generating strategic plan for {power_name}...")
-
+
planning_instructions = load_prompt("planning_instructions.txt", prompts_dir=self.prompts_dir)
if not planning_instructions:
logger.error("Could not load planning_instructions.txt! Cannot generate plan.")
@@ -740,17 +710,17 @@ class BaseModelClient:
# For planning, possible_orders might be less critical for the context,
# but build_context_prompt expects it. We can pass an empty dict or calculate it.
# For simplicity, let's pass empty if not strictly needed by context for planning.
- possible_orders_for_context = {} # game.get_all_possible_orders() if needed by context
-
+ possible_orders_for_context = {} # game.get_all_possible_orders() if needed by context
+
context_prompt = self.build_context_prompt(
game,
board_state,
power_name,
- possible_orders_for_context,
+ possible_orders_for_context,
game_history,
agent_goals=agent_goals,
agent_relationships=agent_relationships,
- agent_private_diary=agent_private_diary_str, # Pass diary string
+ agent_private_diary=agent_private_diary_str, # Pass diary string
prompts_dir=self.prompts_dir,
)
@@ -765,12 +735,12 @@ class BaseModelClient:
try:
# Use run_llm_and_log for the actual LLM call
raw_plan_response = await run_llm_and_log(
- client=self, # Pass self (the client instance)
+ client=self, # Pass self (the client instance)
prompt=full_prompt,
log_file_path=log_file_path,
power_name=power_name,
- phase=game.current_short_phase,
- response_type='plan_generation', # More specific type for run_llm_and_log context
+ phase=game.current_short_phase,
+ response_type="plan_generation", # More specific type for run_llm_and_log context
)
logger.debug(f"[{self.model_name}] Raw LLM response for {power_name} plan generation:\n{raw_plan_response}")
# No parsing needed for the plan, return the raw string
@@ -781,16 +751,16 @@ class BaseModelClient:
success_status = f"Failure: Exception ({type(e).__name__})"
plan_to_return = f"Error: Failed to generate plan for {power_name} due to exception: {e}"
finally:
- if log_file_path: # Only log if a path is provided
+ if log_file_path: # Only log if a path is provided
log_llm_response(
log_file_path=log_file_path,
model_name=self.model_name,
power_name=power_name,
phase=game.current_short_phase if game else "UnknownPhase",
- response_type="plan_generation", # Specific type for CSV logging
- raw_input_prompt=full_prompt, # Renamed from 'full_prompt' to match log_llm_response arg
+ response_type="plan_generation", # Specific type for CSV logging
+ raw_input_prompt=full_prompt, # Renamed from 'full_prompt' to match log_llm_response arg
raw_response=raw_plan_response,
- success=success_status
+ success=success_status,
# token_usage and cost can be added later
)
return plan_to_return
@@ -813,11 +783,7 @@ class OpenAIClient(BaseModelClient):
):
super().__init__(model_name, prompts_dir=prompts_dir)
- self.base_url = (
- base_url
- or os.environ.get("OPENAI_BASE_URL")
- or "https://api.openai.com/v1"
- )
+ self.base_url = base_url or os.environ.get("OPENAI_BASE_URL") or "https://api.openai.com/v1"
self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
if not self.api_key:
@@ -832,18 +798,14 @@ class OpenAIClient(BaseModelClient):
inject_random_seed: bool = True,
) -> str:
try:
- system_prompt_content = (
- f"{generate_random_seed()}\n\n{self.system_prompt}"
- if inject_random_seed
- else self.system_prompt
- )
+ system_prompt_content = f"{generate_random_seed()}\n\n{self.system_prompt}" if inject_random_seed else self.system_prompt
prompt_with_cta = f"{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:"
response = await self.client.chat.completions.create(
model=self.model_name,
messages=[
{"role": "system", "content": system_prompt_content},
- {"role": "user", "content": prompt_with_cta},
+ {"role": "user", "content": prompt_with_cta},
],
temperature=temperature,
max_tokens=self.max_tokens,
@@ -888,20 +850,14 @@ class ClaudeClient(BaseModelClient):
temperature=temperature,
)
if not response.content:
- logger.warning(
- f"[{self.model_name}] Empty content in Claude generate_response. Returning empty."
- )
+ logger.warning(f"[{self.model_name}] Empty content in Claude generate_response. Returning empty.")
return ""
return response.content[0].text.strip() if response.content else ""
except json.JSONDecodeError as json_err:
- logger.error(
- f"[{self.model_name}] JSON decoding failed in generate_response: {json_err}"
- )
+ logger.error(f"[{self.model_name}] JSON decoding failed in generate_response: {json_err}")
return ""
except Exception as e:
- logger.error(
- f"[{self.model_name}] Unexpected error in generate_response: {e}"
- )
+ logger.error(f"[{self.model_name}] Unexpected error in generate_response: {e}")
return ""
@@ -929,19 +885,14 @@ class GeminiClient(BaseModelClient):
full_prompt = system_prompt_content + prompt + "\n\nPROVIDE YOUR RESPONSE BELOW:"
try:
- generation_config = genai.types.GenerationConfig(
- temperature=temperature,
- max_output_tokens=self.max_tokens
- )
+ generation_config = genai.types.GenerationConfig(temperature=temperature, max_output_tokens=self.max_tokens)
response = await self.client.generate_content_async(
contents=full_prompt,
generation_config=generation_config,
)
-
+
if not response or not response.text:
- logger.warning(
- f"[{self.model_name}] Empty Gemini generate_response. Returning empty."
- )
+ logger.warning(f"[{self.model_name}] Empty Gemini generate_response. Returning empty.")
return ""
return response.text.strip()
except Exception as e:
@@ -957,10 +908,7 @@ class DeepSeekClient(BaseModelClient):
def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
super().__init__(model_name, prompts_dir=prompts_dir)
self.api_key = os.environ.get("DEEPSEEK_API_KEY")
- self.client = AsyncDeepSeekOpenAI(
- api_key=self.api_key,
- base_url="https://api.deepseek.com/"
- )
+ self.client = AsyncDeepSeekOpenAI(api_key=self.api_key, base_url="https://api.deepseek.com/")
async def generate_response(self, prompt: str, temperature: float = 0.0, inject_random_seed: bool = True) -> str:
try:
@@ -982,13 +930,11 @@ class DeepSeekClient(BaseModelClient):
temperature=temperature,
max_tokens=self.max_tokens,
)
-
+
logger.debug(f"[{self.model_name}] Raw DeepSeek response:\n{response}")
if not response or not response.choices:
- logger.warning(
- f"[{self.model_name}] No valid response in generate_response."
- )
+ logger.warning(f"[{self.model_name}] No valid response in generate_response.")
return ""
content = response.choices[0].message.content.strip()
@@ -998,9 +944,7 @@ class DeepSeekClient(BaseModelClient):
return content
except Exception as e:
- logger.error(
- f"[{self.model_name}] Unexpected error in generate_response: {e}"
- )
+ logger.error(f"[{self.model_name}] Unexpected error in generate_response: {e}")
return ""
@@ -1028,7 +972,7 @@ class OpenAIResponsesClient(BaseModelClient):
system_prompt_content = f"{random_seed}\n\n{self.system_prompt}"
full_prompt = f"{system_prompt_content}\n\n{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:"
-
+
# Prepare the request payload
payload = {
"model": self.model_name,
@@ -1036,79 +980,60 @@ class OpenAIResponsesClient(BaseModelClient):
"temperature": temperature,
"max_tokens": self.max_tokens,
}
-
- headers = {
- "Content-Type": "application/json",
- "Authorization": f"Bearer {self.api_key}"
- }
-
+
+ headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
+
# Make the API call using aiohttp
async with aiohttp.ClientSession() as session:
async with session.post(self.base_url, json=payload, headers=headers) as response:
if response.status != 200:
error_text = await response.text()
- logger.error(
- f"[{self.model_name}] API error (status {response.status}): {error_text}"
- )
+ logger.error(f"[{self.model_name}] API error (status {response.status}): {error_text}")
return ""
-
+
response_data = await response.json()
-
+
# Extract the text from the nested response structure
# The text is in output[1].content[0].text based on the response
try:
outputs = response_data.get("output", [])
if len(outputs) < 2:
- logger.warning(
- f"[{self.model_name}] Unexpected output structure. Full response: {response_data}"
- )
+ logger.warning(f"[{self.model_name}] Unexpected output structure. Full response: {response_data}")
return ""
-
+
# The message is typically in the second output item
message_output = outputs[1]
if message_output.get("type") != "message":
- logger.warning(
- f"[{self.model_name}] Expected message type in output[1]. Got: {message_output.get('type')}"
- )
+ logger.warning(f"[{self.model_name}] Expected message type in output[1]. Got: {message_output.get('type')}")
return ""
-
+
content_list = message_output.get("content", [])
if not content_list:
- logger.warning(
- f"[{self.model_name}] Empty content list in message output"
- )
+ logger.warning(f"[{self.model_name}] Empty content list in message output")
return ""
-
+
# Look for the content item with type 'output_text'
text_content = ""
for content_item in content_list:
if content_item.get("type") == "output_text":
text_content = content_item.get("text", "")
break
-
+
if not text_content:
- logger.warning(
- f"[{self.model_name}] No output_text found in content. Full content: {content_list}"
- )
+ logger.warning(f"[{self.model_name}] No output_text found in content. Full content: {content_list}")
return ""
-
+
return text_content.strip()
-
+
except (KeyError, IndexError, TypeError) as e:
- logger.error(
- f"[{self.model_name}] Error parsing response structure: {e}. Full response: {response_data}"
- )
+ logger.error(f"[{self.model_name}] Error parsing response structure: {e}. Full response: {response_data}")
return ""
-
+
except aiohttp.ClientError as e:
- logger.error(
- f"[{self.model_name}] HTTP client error in generate_response: {e}"
- )
+ logger.error(f"[{self.model_name}] HTTP client error in generate_response: {e}")
return ""
except Exception as e:
- logger.error(
- f"[{self.model_name}] Unexpected error in generate_response: {e}"
- )
+ logger.error(f"[{self.model_name}] Unexpected error in generate_response: {e}")
return ""
@@ -1123,17 +1048,14 @@ class OpenRouterClient(BaseModelClient):
model_name = f"openrouter/{model_name}"
if model_name.startswith("openrouter-"):
model_name = model_name.replace("openrouter-", "")
-
+
super().__init__(model_name, prompts_dir=prompts_dir)
self.api_key = os.environ.get("OPENROUTER_API_KEY")
if not self.api_key:
raise ValueError("OPENROUTER_API_KEY environment variable is required")
-
- self.client = AsyncOpenAI(
- base_url="https://openrouter.ai/api/v1",
- api_key=self.api_key
- )
-
+
+ self.client = AsyncOpenAI(base_url="https://openrouter.ai/api/v1", api_key=self.api_key)
+
logger.debug(f"[{self.model_name}] Initialized OpenRouter client")
async def generate_response(self, prompt: str, temperature: float = 0.0, inject_random_seed: bool = True) -> str:
@@ -1150,26 +1072,23 @@ class OpenRouterClient(BaseModelClient):
# Prepare standard OpenAI-compatible request
response = await self.client.chat.completions.create(
model=self.model_name,
- messages=[
- {"role": "system", "content": system_prompt_content},
- {"role": "user", "content": prompt_with_cta}
- ],
+ messages=[{"role": "system", "content": system_prompt_content}, {"role": "user", "content": prompt_with_cta}],
max_tokens=self.max_tokens,
temperature=temperature,
)
-
+
if not response.choices:
logger.warning(f"[{self.model_name}] OpenRouter returned no choices")
return ""
-
+
content = response.choices[0].message.content.strip()
if not content:
logger.warning(f"[{self.model_name}] OpenRouter returned empty content")
return ""
-
+
# Parse or return the raw content
return content
-
+
except Exception as e:
error_msg = str(e)
# Check if it's a specific OpenRouter error
@@ -1200,7 +1119,7 @@ class TogetherAIClient(BaseModelClient):
self.api_key = os.environ.get("TOGETHER_API_KEY")
if not self.api_key:
raise ValueError("TOGETHER_API_KEY environment variable is required for TogetherAIClient")
-
+
# The model_name passed to super() is used for logging and identification.
# The actual model name for the API call is self.model_name (from super class).
self.client = AsyncTogether(api_key=self.api_key)
@@ -1211,7 +1130,7 @@ class TogetherAIClient(BaseModelClient):
Generates a response from the Together AI model.
"""
logger.debug(f"[{self.model_name}] Generating response with prompt (first 100 chars): {prompt[:100]}...")
-
+
messages = [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": prompt},
@@ -1226,7 +1145,7 @@ class TogetherAIClient(BaseModelClient):
# Consider adding max_tokens, temperature, etc. as needed
# max_tokens=2048, # Example
)
-
+
if response.choices and response.choices[0].message and response.choices[0].message.content is not None:
content = response.choices[0].message.content
logger.debug(f"[{self.model_name}] Received response (first 100 chars): {content[:100]}...")
@@ -1236,16 +1155,16 @@ class TogetherAIClient(BaseModelClient):
return ""
except TogetherAPIError as e:
logger.error(f"[{self.model_name}] Together AI API error: {e}", exc_info=True)
- return f"Error: Together AI API error - {str(e)}" # Return a string with error info
+ return f"Error: Together AI API error - {str(e)}" # Return a string with error info
except Exception as e:
logger.error(f"[{self.model_name}] Unexpected error in TogetherAIClient: {e}", exc_info=True)
- return f"Error: Unexpected error - {str(e)}" # Return a string with error info
+ return f"Error: Unexpected error - {str(e)}" # Return a string with error info
##############################################################################
# RequestsOpenAIClient – sync requests, wrapped async (original + api_key)
##############################################################################
-import requests, asyncio
+
class RequestsOpenAIClient(BaseModelClient):
"""
@@ -1266,18 +1185,14 @@ class RequestsOpenAIClient(BaseModelClient):
if not self.api_key:
raise ValueError("OPENAI_API_KEY missing and no inline key provided")
- self.base_url = (
- base_url
- or os.environ.get("OPENAI_BASE_URL")
- or "https://api.openai.com/v1"
- ).rstrip("/")
+ self.base_url = (base_url or os.environ.get("OPENAI_BASE_URL") or "https://api.openai.com/v1").rstrip("/")
self.endpoint = f"{self.base_url}/chat/completions"
# ---------------- internal blocking helper ---------------- #
def _post_sync(self, payload: dict) -> dict:
headers = {
- "Content-Type": "application/json",
+ "Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
}
r = requests.post(self.endpoint, headers=headers, json=payload, timeout=60)
@@ -1291,17 +1206,13 @@ class RequestsOpenAIClient(BaseModelClient):
temperature: float = 0.0,
inject_random_seed: bool = True,
) -> str:
- system_prompt_content = (
- f"{generate_random_seed()}\n\n{self.system_prompt}"
- if inject_random_seed
- else self.system_prompt
- )
+ system_prompt_content = f"{generate_random_seed()}\n\n{self.system_prompt}" if inject_random_seed else self.system_prompt
payload = {
"model": self.model_name,
"messages": [
{"role": "system", "content": system_prompt_content},
- {"role": "user", "content": f"{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:"},
+ {"role": "user", "content": f"{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:"},
],
"temperature": temperature,
"max_tokens": self.max_tokens,
@@ -1322,15 +1233,15 @@ class RequestsOpenAIClient(BaseModelClient):
return ""
-
##############################################################################
# 3) Factory to Load Model Client
##############################################################################
class ModelSpec(NamedTuple):
- prefix: Optional[str] # 'openai', 'requests', …
- model: str # 'gpt-4o'
- base: Optional[str] # 'https://proxy.foo'
- key: Optional[str] # 'sk-…' (may be None)
+ prefix: Optional[str] # 'openai', 'requests', …
+ model: str # 'gpt-4o'
+ base: Optional[str] # 'https://proxy.foo'
+ key: Optional[str] # 'sk-…' (may be None)
+
def _parse_model_spec(raw: str) -> ModelSpec:
"""
@@ -1339,11 +1250,11 @@ def _parse_model_spec(raw: str) -> ModelSpec:
"""
raw = raw.strip()
- pre_hash, _, key_part = raw.partition("#")
- pre_at, _, base_part = pre_hash.partition("@")
+ pre_hash, _, key_part = raw.partition("#")
+ pre_at, _, base_part = pre_hash.partition("@")
maybe_pref, sep, model_part = pre_at.partition(":")
- if sep: # explicit prefix was present
+ if sep: # explicit prefix was present
prefix, model = maybe_pref.lower(), model_part
else:
prefix, model = None, maybe_pref
@@ -1379,17 +1290,17 @@ def load_model_client(model_id: str, prompts_dir: Optional[str] = None) -> BaseM
match spec.prefix:
case "openai" | "oai":
return OpenAIClient(
- model_name = spec.model,
- prompts_dir = prompts_dir,
- base_url = spec.base,
- api_key = inline_key,
+ model_name=spec.model,
+ prompts_dir=prompts_dir,
+ base_url=spec.base,
+ api_key=inline_key,
)
case "requests" | "req":
return RequestsOpenAIClient(
- model_name = spec.model,
- prompts_dir = prompts_dir,
- base_url = spec.base,
- api_key = inline_key,
+ model_name=spec.model,
+ prompts_dir=prompts_dir,
+ base_url=spec.base,
+ api_key=inline_key,
)
case "responses" | "oai-resp" | "openai-responses":
return OpenAIResponsesClient(spec.model, prompts_dir, api_key=inline_key)
@@ -1409,7 +1320,7 @@ def load_model_client(model_id: str, prompts_dir: Optional[str] = None) -> BaseM
# ------------------------------------------------------------------ #
# 2. Heuristic fallback path (identical to the original behaviour) #
# ------------------------------------------------------------------ #
- lower_id = spec.model.lower()
+ lower_id = spec.model.lower()
if lower_id == "o3-pro":
return OpenAIResponsesClient(spec.model, prompts_dir, api_key=inline_key)
@@ -1432,17 +1343,15 @@ def load_model_client(model_id: str, prompts_dir: Optional[str] = None) -> BaseM
# Default: OpenAI-compatible async client
return OpenAIClient(
- model_name = spec.model,
- prompts_dir = prompts_dir,
- base_url = spec.base,
- api_key = inline_key,
+ model_name=spec.model,
+ prompts_dir=prompts_dir,
+ base_url=spec.base,
+ api_key=inline_key,
)
return OpenAIClient(model_name, prompts_dir, base_url)
-
-
##############################################################################
# 1) Add a method to filter visible messages (near top-level or in BaseModelClient)
##############################################################################
@@ -1453,11 +1362,6 @@ def get_visible_messages_for_power(conversation_messages, power_name):
visible = []
for msg in conversation_messages:
# GLOBAL might be 'ALL' or 'GLOBAL' depending on your usage
- if (
- msg["recipient"] == "ALL"
- or msg["recipient"] == "GLOBAL"
- or msg["sender"] == power_name
- or msg["recipient"] == power_name
- ):
+ if msg["recipient"] == "ALL" or msg["recipient"] == "GLOBAL" or msg["sender"] == power_name or msg["recipient"] == power_name:
visible.append(msg)
return visible # already in chronological order if appended that way
diff --git a/ai_diplomacy/diary_logic.py b/ai_diplomacy/diary_logic.py
index aa0be0f..1cfa164 100644
--- a/ai_diplomacy/diary_logic.py
+++ b/ai_diplomacy/diary_logic.py
@@ -11,25 +11,28 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+
def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> str | None:
"""A local copy of the helper from agent.py to avoid circular imports."""
import os
+
try:
if prompts_dir:
filepath = os.path.join(prompts_dir, filename)
else:
current_dir = os.path.dirname(os.path.abspath(__file__))
- default_prompts_dir = os.path.join(current_dir, 'prompts')
+ default_prompts_dir = os.path.join(current_dir, "prompts")
filepath = os.path.join(default_prompts_dir, filename)
- with open(filepath, 'r', encoding='utf-8') as f:
+ with open(filepath, "r", encoding="utf-8") as f:
return f.read()
except Exception as e:
logger.error(f"Error loading prompt file {filepath}: {e}")
return None
+
async def run_diary_consolidation(
- agent: 'DiplomacyAgent',
+ agent: "DiplomacyAgent",
game: "Game",
log_file_path: str,
entries_to_keep_unsummarized: int = 6,
@@ -39,70 +42,42 @@ async def run_diary_consolidation(
Consolidate older diary entries while keeping recent ones.
This is the logic moved from the DiplomacyAgent class.
"""
- logger.info(
- f"[{agent.power_name}] CONSOLIDATION START — "
- f"{len(agent.full_private_diary)} total full entries"
- )
+ logger.info(f"[{agent.power_name}] CONSOLIDATION START — {len(agent.full_private_diary)} total full entries")
- full_entries = [
- e for e in agent.full_private_diary
- if not e.startswith("[CONSOLIDATED HISTORY]")
- ]
+ full_entries = [e for e in agent.full_private_diary if not e.startswith("[CONSOLIDATED HISTORY]")]
if len(full_entries) <= entries_to_keep_unsummarized:
agent.private_diary = list(agent.full_private_diary)
- logger.info(
- f"[{agent.power_name}] ≤ {entries_to_keep_unsummarized} full entries — "
- "skipping consolidation"
- )
+ logger.info(f"[{agent.power_name}] ≤ {entries_to_keep_unsummarized} full entries — skipping consolidation")
return
boundary_entry = full_entries[-entries_to_keep_unsummarized]
match = re.search(r"\[[SFWRAB]\s*(\d{4})", boundary_entry)
if not match:
- logger.error(
- f"[{agent.power_name}] Could not parse year from boundary entry; "
- "aborting consolidation"
- )
+ logger.error(f"[{agent.power_name}] Could not parse year from boundary entry; aborting consolidation")
agent.private_diary = list(agent.full_private_diary)
return
cutoff_year = int(match.group(1))
- logger.info(
- f"[{agent.power_name}] Cut-off year for consolidation: {cutoff_year}"
- )
+ logger.info(f"[{agent.power_name}] Cut-off year for consolidation: {cutoff_year}")
def _entry_year(entry: str) -> int | None:
m = re.search(r"\[[SFWRAB]\s*(\d{4})", entry)
return int(m.group(1)) if m else None
- entries_to_summarize = [
- e for e in full_entries
- if (_entry_year(e) is not None and _entry_year(e) < cutoff_year)
- ]
- entries_to_keep = [
- e for e in full_entries
- if (_entry_year(e) is None or _entry_year(e) >= cutoff_year)
- ]
+ entries_to_summarize = [e for e in full_entries if (_entry_year(e) is not None and _entry_year(e) < cutoff_year)]
+ entries_to_keep = [e for e in full_entries if (_entry_year(e) is None or _entry_year(e) >= cutoff_year)]
- logger.info(
- f"[{agent.power_name}] Summarising {len(entries_to_summarize)} entries; "
- f"keeping {len(entries_to_keep)} recent entries verbatim"
- )
+ logger.info(f"[{agent.power_name}] Summarising {len(entries_to_summarize)} entries; keeping {len(entries_to_keep)} recent entries verbatim")
if not entries_to_summarize:
agent.private_diary = list(agent.full_private_diary)
- logger.warning(
- f"[{agent.power_name}] No eligible entries to summarise; "
- "context diary left unchanged"
- )
+ logger.warning(f"[{agent.power_name}] No eligible entries to summarise; context diary left unchanged")
return
prompt_template = _load_prompt_file("diary_consolidation_prompt.txt", prompts_dir=prompts_dir)
if not prompt_template:
- logger.error(
- f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting"
- )
+ logger.error(f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting")
return
prompt = prompt_template.format(
@@ -132,27 +107,18 @@ async def run_diary_consolidation(
new_summary_entry = f"[CONSOLIDATED HISTORY] {consolidated_text}"
agent.private_diary = [new_summary_entry] + entries_to_keep
success_flag = "TRUE"
- logger.info(
- f"[{agent.power_name}] Consolidation complete — "
- f"{len(agent.private_diary)} context entries now"
- )
+ logger.info(f"[{agent.power_name}] Consolidation complete — {len(agent.private_diary)} context entries now")
except Exception as exc:
- logger.error(
- f"[{agent.power_name}] Diary consolidation failed: {exc}", exc_info=True
- )
+ logger.error(f"[{agent.power_name}] Diary consolidation failed: {exc}", exc_info=True)
finally:
log_llm_response(
log_file_path=log_file_path,
- model_name=(
- consolidation_client.model_name
- if consolidation_client is not None
- else agent.client.model_name
- ),
+ model_name=(consolidation_client.model_name if consolidation_client is not None else agent.client.model_name),
power_name=agent.power_name,
phase=game.current_short_phase,
response_type="diary_consolidation",
raw_input_prompt=prompt,
raw_response=raw_response,
success=success_flag,
- )
\ No newline at end of file
+ )
diff --git a/ai_diplomacy/formatter.py b/ai_diplomacy/formatter.py
index d61d5c4..20edfec 100644
--- a/ai_diplomacy/formatter.py
+++ b/ai_diplomacy/formatter.py
@@ -3,14 +3,13 @@ Formatter module for converting natural language LLM responses to structured JSO
Uses Gemini 2.5 Flash via OpenRouter to extract and format information from reasoning-focused responses.
"""
-import json
import logging
-import os
-from typing import Dict, Any, Optional
+from typing import Optional
from pathlib import Path
# Import logging function and model configuration
from .utils import log_llm_response, get_special_models
+
# Import client loading function
from .clients import load_model_client
@@ -26,77 +25,73 @@ FORMAT_ORDER_DIARY = "order_diary"
async def format_with_gemini_flash(
- raw_response: str,
- format_type: str,
- power_name: Optional[str] = None,
- phase: Optional[str] = None,
- log_file_path: Optional[str] = None
+ raw_response: str, format_type: str, power_name: Optional[str] = None, phase: Optional[str] = None, log_file_path: Optional[str] = None
) -> str:
"""
Format a natural language response into required JSON structure using Gemini Flash.
-
+
Args:
raw_response: Natural language response from primary LLM
format_type: Type of formatting required (e.g., FORMAT_ORDERS, FORMAT_STATE_UPDATE)
power_name: Optional power name for logging
phase: Optional phase for logging
log_file_path: Optional path for CSV logging
-
+
Returns:
JSON string in the expected format
"""
# Map format types to prompt files
format_prompts = {
FORMAT_STATE_UPDATE: "formatting/format_state_update.txt",
- FORMAT_CONVERSATION: "formatting/format_conversation.txt",
+ FORMAT_CONVERSATION: "formatting/format_conversation.txt",
FORMAT_NEGOTIATION_DIARY: "formatting/format_negotiation_diary.txt",
FORMAT_ORDERS: "formatting/format_orders.txt",
FORMAT_INITIAL_STATE: "formatting/format_initial_state.txt",
- FORMAT_ORDER_DIARY: "formatting/format_order_diary.txt"
+ FORMAT_ORDER_DIARY: "formatting/format_order_diary.txt",
}
-
+
if format_type not in format_prompts:
raise ValueError(f"Unknown format type: {format_type}")
-
+
# Load the formatting prompt
prompt_file = Path(__file__).parent / "prompts" / format_prompts[format_type]
if not prompt_file.exists():
raise FileNotFoundError(f"Formatting prompt not found: {prompt_file}")
-
- with open(prompt_file, 'r') as f:
+
+ with open(prompt_file, "r") as f:
format_prompt = f.read()
-
+
# Replace placeholder with actual response
format_prompt = format_prompt.replace("[RAW_RESPONSE]", raw_response)
-
+
# Get model name from configuration
special_models = get_special_models()
model_name = special_models["formatter"]
-
+
# Load the formatter client using the same logic as other models
formatter_client = load_model_client(model_name)
-
+
try:
logger.info(f"[FORMATTER] Calling {model_name} for {format_type} formatting")
-
+
# Create the full prompt with system message
system_content = "You are a precise formatting assistant. Extract and format information exactly as requested."
formatter_client.set_system_prompt(system_content)
-
+
# Use the client's generate_response method
formatted_response = await formatter_client.generate_response(
prompt=format_prompt,
temperature=0, # Deterministic formatting
- inject_random_seed=False # No need for random seed in formatting
+ inject_random_seed=False, # No need for random seed in formatting
)
-
+
if not formatted_response:
logger.warning(f"[FORMATTER] {model_name} returned empty response")
return ""
-
+
# Log successful formatting
logger.info(f"[FORMATTER] Successfully formatted {format_type} response")
-
+
# Strip any markdown formatting that Gemini might add
if formatted_response.startswith("```json"):
formatted_response = formatted_response[7:]
@@ -105,7 +100,7 @@ async def format_with_gemini_flash(
if formatted_response.endswith("```"):
formatted_response = formatted_response[:-3]
formatted_response = formatted_response.strip()
-
+
# Log if requested
if log_file_path:
log_llm_response(
@@ -116,14 +111,14 @@ async def format_with_gemini_flash(
response_type=f"format_{format_type}",
raw_input_prompt=format_prompt,
raw_response=formatted_response,
- success="Success"
+ success="Success",
)
-
+
return formatted_response
-
+
except Exception as e:
logger.error(f"[FORMATTER] Error calling Gemini Flash: {e}")
-
+
# Log error if requested
if log_file_path:
log_llm_response(
@@ -134,11 +129,11 @@ async def format_with_gemini_flash(
response_type=f"format_{format_type}",
raw_input_prompt=format_prompt,
raw_response=f"ERROR: {str(e)}",
- success=f"Failure: {type(e).__name__}"
+ success=f"Failure: {type(e).__name__}",
)
-
+
# Return empty structure based on format type
if format_type == FORMAT_CONVERSATION:
return "[]"
else:
- return "{}"
\ No newline at end of file
+ return "{}"
diff --git a/ai_diplomacy/game_history.py b/ai_diplomacy/game_history.py
index 27abd23..c0bf78e 100644
--- a/ai_diplomacy/game_history.py
+++ b/ai_diplomacy/game_history.py
@@ -3,7 +3,6 @@ import logging
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List, Optional
-from pprint import pformat
logger = logging.getLogger("utils")
logger.setLevel(logging.INFO)
@@ -24,12 +23,8 @@ class Phase:
plans: Dict[str, str] = field(default_factory=dict)
messages: List[Message] = field(default_factory=list)
submitted_orders_by_power: Dict[str, List[str]] = field(default_factory=dict)
- orders_by_power: Dict[str, List[str]] = field(
- default_factory=lambda: defaultdict(list)
- )
- results_by_power: Dict[str, List[List[str]]] = field(
- default_factory=lambda: defaultdict(list)
- )
+ orders_by_power: Dict[str, List[str]] = field(default_factory=lambda: defaultdict(list))
+ results_by_power: Dict[str, List[List[str]]] = field(default_factory=lambda: defaultdict(list))
# NEW: Store phase-end summaries provided by each power
phase_summaries: Dict[str, str] = field(default_factory=dict)
# NEW: Store experience/journal updates from each power for this phase
@@ -37,11 +32,9 @@ class Phase:
def add_plan(self, power_name: str, plan: str):
self.plans[power_name] = plan
-
+
def add_message(self, sender: str, recipient: str, content: str):
- self.messages.append(
- Message(sender=sender, recipient=recipient, content=content)
- )
+ self.messages.append(Message(sender=sender, recipient=recipient, content=content))
def add_orders(self, power: str, orders: List[str], results: List[List[str]]):
self.orders_by_power[power].extend(orders)
@@ -110,14 +103,10 @@ class GameHistory:
phase.plans[power_name] = plan
logger.debug(f"Added plan for {power_name} in {phase_name}")
- def add_message(
- self, phase_name: str, sender: str, recipient: str, message_content: str
- ):
+ def add_message(self, phase_name: str, sender: str, recipient: str, message_content: str):
phase = self._get_phase(phase_name)
if phase:
- message = Message(
- sender=sender, recipient=recipient, content=message_content
- )
+ message = Message(sender=sender, recipient=recipient, content=message_content)
phase.messages.append(message)
logger.debug(f"Added message from {sender} to {recipient} in {phase_name}")
@@ -147,16 +136,12 @@ class GameHistory:
phase.experience_updates[power_name] = update
logger.debug(f"Added experience update for {power_name} in {phase_name}")
- def get_strategic_directives(self):
+ def get_strategic_directives(self):
# returns for last phase only if exists
- if not self.phases:
+ if not self.phases:
return {}
return self.phases[-1].plans
-
-
-
-
def get_order_history_for_prompt(
self,
game: "Game",
@@ -194,10 +179,7 @@ class GameHistory:
return str(tag).strip().lower()
engine_phases = {ph.name: ph for ph in getattr(game, "get_phase_history", lambda: [])()}
- eng2code = {
- "AUSTRIA": "AUT", "ENGLAND": "ENG", "FRANCE": "FRA",
- "GERMANY": "GER", "ITALY": "ITA", "RUSSIA": "RUS", "TURKEY": "TUR"
- }
+ eng2code = {"AUSTRIA": "AUT", "ENGLAND": "ENG", "FRANCE": "FRA", "GERMANY": "GER", "ITALY": "ITA", "RUSSIA": "RUS", "TURKEY": "TUR"}
norm = game.map.norm
out_lines = ["**ORDER HISTORY (Recent Rounds)**"]
@@ -209,14 +191,14 @@ class GameHistory:
for pwr in sorted(set(ph.orders_by_power) | set(ph.submitted_orders_by_power)):
submitted = ph.submitted_orders_by_power.get(pwr, [])
- accepted = ph.orders_by_power.get(pwr, [])
+ accepted = ph.orders_by_power.get(pwr, [])
if isinstance(submitted, str):
submitted = [submitted]
if isinstance(accepted, str):
accepted = [accepted]
- def _norm_keep(o): # keep WAIVE readable
+ def _norm_keep(o): # keep WAIVE readable
return o if o.upper() == "WAIVE" else norm(o)
sub_norm = {_norm_keep(o): o for o in submitted}
@@ -264,13 +246,6 @@ class GameHistory:
return "\n(No orders were issued in recent history)\n"
return "\n".join(out_lines)
-
-
-
-
-
-
-
def get_messages_this_round(self, power_name: str, current_phase_name: str) -> str:
current_phase: Optional[Phase] = None
for phase_obj in self.phases:
@@ -281,7 +256,7 @@ class GameHistory:
if not current_phase:
return f"\n(No messages found for current phase: {current_phase_name})\n"
- messages_str = ""
+ messages_str = ""
global_msgs_content = current_phase.get_global_messages()
if global_msgs_content:
@@ -299,7 +274,7 @@ class GameHistory:
messages_str += "\n"
else:
messages_str += "\n**PRIVATE MESSAGES TO/FROM YOU THIS ROUND:**\n (No private messages this round)\n"
-
+
if not global_msgs_content and not private_msgs_dict:
return f"\n(No messages recorded for current phase: {current_phase_name})\n"
@@ -313,10 +288,10 @@ class GameHistory:
"""
if not self.phases:
return []
-
+
# Get the most recent 2 phases including current phase
recent_phases = self.phases[-2:] if len(self.phases) >= 2 else self.phases[-1:]
-
+
# Collect all messages sent TO this power
messages_to_power = []
for phase in recent_phases:
@@ -325,48 +300,44 @@ class GameHistory:
if msg.recipient == power_name or (msg.recipient == "GLOBAL" and msg.sender != power_name):
# Skip if sender is this power (don't need to respond to own messages)
if msg.sender != power_name:
- messages_to_power.append({
- 'sender': msg.sender,
- 'content': msg.content,
- 'phase': phase.name
- })
-
+ messages_to_power.append({"sender": msg.sender, "content": msg.content, "phase": phase.name})
+
# Add debug logging
logger.info(f"Found {len(messages_to_power)} messages to {power_name} across {len(recent_phases)} phases")
if not messages_to_power:
logger.info(f"No messages found for {power_name} to respond to")
-
+
# Take the most recent 'limit' messages
return messages_to_power[-limit:] if messages_to_power else []
-
+
def get_ignored_messages_by_power(self, sender_name: str, num_phases: int = 3) -> Dict[str, List[Dict[str, str]]]:
"""
Identifies which powers are not responding to messages from sender_name.
Returns a dict mapping power names to their ignored messages.
-
+
A message is considered ignored if:
1. It was sent from sender_name to another power (private)
2. No response from that power was received in the same or next phase
"""
ignored_by_power = {}
-
+
# Get recent phases
recent_phases = self.phases[-num_phases:] if self.phases else []
if not recent_phases:
return ignored_by_power
-
+
for i, phase in enumerate(recent_phases):
# Get messages sent by sender to specific powers (not global)
sender_messages = []
for msg in phase.messages:
# Handle both Message objects and dict objects
if isinstance(msg, Message):
- if msg.sender == sender_name and msg.recipient not in ['GLOBAL', 'ALL']:
+ if msg.sender == sender_name and msg.recipient not in ["GLOBAL", "ALL"]:
sender_messages.append(msg)
else: # Assume dict
- if msg['sender'] == sender_name and msg['recipient'] not in ['GLOBAL', 'ALL']:
+ if msg["sender"] == sender_name and msg["recipient"] not in ["GLOBAL", "ALL"]:
sender_messages.append(msg)
-
+
# Check for responses in this and next phases
for msg in sender_messages:
# Handle both Message objects and dict objects
@@ -374,37 +345,36 @@ class GameHistory:
recipient = msg.recipient
msg_content = msg.content
else:
- recipient = msg['recipient']
- msg_content = msg['content']
-
+ recipient = msg["recipient"]
+ msg_content = msg["content"]
+
# Look for responses in current phase and next phases
found_response = False
-
+
# Check remaining phases starting from current
- for check_phase in recent_phases[i:min(i+2, len(recent_phases))]:
+ for check_phase in recent_phases[i : min(i + 2, len(recent_phases))]:
# Look for messages FROM the recipient TO the sender (direct response)
# or FROM the recipient to GLOBAL/ALL that might acknowledge sender
response_msgs = []
for m in check_phase.messages:
if isinstance(m, Message):
- if m.sender == recipient and (m.recipient == sender_name or
- (m.recipient in ['GLOBAL', 'ALL'] and sender_name in m.content)):
+ if m.sender == recipient and (
+ m.recipient == sender_name or (m.recipient in ["GLOBAL", "ALL"] and sender_name in m.content)
+ ):
response_msgs.append(m)
else: # Assume dict
- if m['sender'] == recipient and (m['recipient'] == sender_name or
- (m['recipient'] in ['GLOBAL', 'ALL'] and sender_name in m.get('content', ''))):
+ if m["sender"] == recipient and (
+ m["recipient"] == sender_name or (m["recipient"] in ["GLOBAL", "ALL"] and sender_name in m.get("content", ""))
+ ):
response_msgs.append(m)
-
+
if response_msgs:
found_response = True
break
-
+
if not found_response:
if recipient not in ignored_by_power:
ignored_by_power[recipient] = []
- ignored_by_power[recipient].append({
- 'phase': phase.name,
- 'content': msg_content
- })
-
- return ignored_by_power
\ No newline at end of file
+ ignored_by_power[recipient].append({"phase": phase.name, "content": msg_content})
+
+ return ignored_by_power
diff --git a/ai_diplomacy/game_logic.py b/ai_diplomacy/game_logic.py
index 855c426..8a44871 100644
--- a/ai_diplomacy/game_logic.py
+++ b/ai_diplomacy/game_logic.py
@@ -3,7 +3,7 @@ import logging
import os
import json
import asyncio
-from typing import Dict, List, Tuple, Optional, Any
+from typing import Dict, Tuple, Optional, Any
from argparse import Namespace
from diplomacy import Game
@@ -19,6 +19,7 @@ logger = logging.getLogger(__name__)
# --- Serialization / Deserialization ---
+
def serialize_agent(agent: DiplomacyAgent) -> dict:
"""Converts an agent object to a JSON-serializable dictionary."""
return {
@@ -31,12 +32,8 @@ def serialize_agent(agent: DiplomacyAgent) -> dict:
"private_diary": agent.private_diary,
}
-def deserialize_agent(
- agent_data: dict,
- prompts_dir: Optional[str] = None,
- *,
- override_model_id: Optional[str] = None
-) -> DiplomacyAgent:
+
+def deserialize_agent(agent_data: dict, prompts_dir: Optional[str] = None, *, override_model_id: Optional[str] = None) -> DiplomacyAgent:
"""
Recreates an agent object from a dictionary.
@@ -68,34 +65,31 @@ def deserialize_agent(
# --- State Management ---
# game_logic.py
-_PHASE_ORDER = ["M", "R", "A"] # Movement → Retreats → Adjustments
+_PHASE_ORDER = ["M", "R", "A"] # Movement → Retreats → Adjustments
+
def _next_phase_name(short: str) -> str:
"""
Return the Diplomacy phase string that chronologically follows *short*.
(E.g. S1901M → S1901R, S1901R → W1901A, W1901A → S1902M)
"""
- season = short[0] # 'S' | 'W'
- year = int(short[1:5])
- typ = short[-1] # 'M' | 'R' | 'A'
+ season = short[0] # 'S' | 'W'
+ year = int(short[1:5])
+ typ = short[-1] # 'M' | 'R' | 'A'
idx = _PHASE_ORDER.index(typ)
- if idx < 2: # still in the same season
- return f"{season}{year}{_PHASE_ORDER[idx+1]}"
+ if idx < 2: # still in the same season
+ return f"{season}{year}{_PHASE_ORDER[idx + 1]}"
# typ was 'A' → roll season
- if season == "S": # summer → winter, same year
+ if season == "S": # summer → winter, same year
return f"W{year}M"
- else: # winter→ spring, next year
- return f"S{year+1}M"
+ else: # winter→ spring, next year
+ return f"S{year + 1}M"
+
def save_game_state(
- game: Game,
- agents: Dict[str, DiplomacyAgent],
- game_history: GameHistory,
- output_path: str,
- run_config: Namespace,
- completed_phase_name: str
+ game: Game, agents: Dict[str, DiplomacyAgent], game_history: GameHistory, output_path: str, run_config: Namespace, completed_phase_name: str
):
"""
Serialise the entire game to JSON, preserving per-phase custom metadata
@@ -130,9 +124,7 @@ def save_game_state(
if extras:
previous_phase_extras[phase["name"]] = extras
except Exception as exc:
- logger.warning(
- "Could not load previous save to retain metadata: %s", exc, exc_info=True
- )
+ logger.warning("Could not load previous save to retain metadata: %s", exc, exc_info=True)
# -------------------------------------------------------------- #
# 2. Build the fresh base structure from the diplomacy library. #
@@ -143,16 +135,12 @@ def save_game_state(
# 3. Walk every phase and merge the metadata back in. #
# -------------------------------------------------------------- #
# Capture the *current* snapshot of every live agent exactly once.
- current_state_agents = {
- p_name: serialize_agent(p_agent)
- for p_name, p_agent in agents.items()
- if not game.powers[p_name].is_eliminated()
- }
+ current_state_agents = {p_name: serialize_agent(p_agent) for p_name, p_agent in agents.items() if not game.powers[p_name].is_eliminated()}
for phase_block in saved_game.get("phases", []):
if int(phase_block["name"][1:5]) > run_config.max_year:
break
-
+
phase_name = phase_block["name"]
# 3a. Re-attach anything we cached from a previous save.
@@ -167,31 +155,24 @@ def save_game_state(
# Plans for this phase – may be empty in non-movement phases.
phase_obj = game_history._get_phase(phase_name)
- phase_block["state_history_plans"] = (
- phase_obj.plans if phase_obj else {}
- )
-
+ phase_block["state_history_plans"] = phase_obj.plans if phase_obj else {}
# -------------------------------------------------------------- #
# 4. Attach top-level metadata and write atomically. #
# -------------------------------------------------------------- #
saved_game["phase_summaries"] = getattr(game, "phase_summaries", {})
- saved_game["final_agent_states"] = {
- p_name: {"relationships": a.relationships, "goals": a.goals}
- for p_name, a in agents.items()
- }
+ saved_game["final_agent_states"] = {p_name: {"relationships": a.relationships, "goals": a.goals} for p_name, a in agents.items()}
# Filter out phases > max_year
- #saved_game["phases"] = [
+ # saved_game["phases"] = [
# ph for ph in saved_game["phases"]
# if int(ph["name"][1:5]) <= run_config.max_year # <= 1902, for example
- #]
+ # ]
atomic_write_json(saved_game, output_path)
logger.info("Game state saved successfully.")
-
def load_game_state(
run_dir: str,
game_file_name: str,
@@ -204,7 +185,7 @@ def load_game_state(
raise FileNotFoundError(f"Cannot resume. Save file not found at: {game_file_path}")
logger.info(f"Loading game state from: {game_file_path}")
- with open(game_file_path, 'r') as f:
+ with open(game_file_path, "r") as f:
saved_game_data = json.load(f)
# If resuming, find the specified phase and truncate the data after it
@@ -213,49 +194,49 @@ def load_game_state(
try:
# Find the index of the phase *before* the one we want to resume from.
# We will start the simulation *at* the resume_from_phase.
- resume_idx = next(i for i, phase in enumerate(saved_game_data['phases']) if phase['name'] == resume_from_phase)
+ resume_idx = next(i for i, phase in enumerate(saved_game_data["phases"]) if phase["name"] == resume_from_phase)
# Truncate the list to exclude everything after the resume phase
# Note: the state saved for a given phase represents the state at the beginning of that phase.
- saved_game_data['phases'] = saved_game_data['phases'][:resume_idx+1]
+ saved_game_data["phases"] = saved_game_data["phases"][: resume_idx + 1]
# Wipe any data that must be regenerated.
for key in ("orders", "results", "messages"):
- saved_game_data['phases'][-1].pop(key, None)
+ saved_game_data["phases"][-1].pop(key, None)
logger.info(f"Game history truncated to {len(saved_game_data['phases'])} phases. The next phase to run will be {resume_from_phase}.")
except StopIteration:
# If the phase is not found, maybe it's the first phase (S1901M)
if resume_from_phase == "S1901M":
- saved_game_data['phases'] = []
- logger.info("Resuming from S1901M. Starting with a clean history.")
+ saved_game_data["phases"] = []
+ logger.info("Resuming from S1901M. Starting with a clean history.")
else:
raise ValueError(f"Resume phase '{resume_from_phase}' not found in the save file.")
# Reconstruct the Game object
- last_phase = saved_game_data['phases'][-1]
+ last_phase = saved_game_data["phases"][-1]
# Wipe the data that must be regenerated **but preserve the keys**
- last_phase['orders'] = {} # was dict
- last_phase['results'] = {} # was dict
- last_phase['messages'] = []
+ last_phase["orders"] = {} # was dict
+ last_phase["results"] = {} # was dict
+ last_phase["messages"] = []
game = from_saved_game_format(saved_game_data)
- game.phase_summaries = saved_game_data.get('phase_summaries', {})
+ game.phase_summaries = saved_game_data.get("phase_summaries", {})
# Reconstruct agents and game history from the *last* valid phase in the data
- if not saved_game_data['phases']:
+ if not saved_game_data["phases"]:
# This happens if we are resuming from the very beginning (S1901M)
logger.info("No previous phases found. Initializing fresh agents and history.")
- agents = {} # Will be created by the main loop
+ agents = {} # Will be created by the main loop
game_history = GameHistory()
else:
# We save the game state up to & including the current (uncompleted) phase.
# So we need to grab the agent state from the previous (completed) phase.
- if len(saved_game_data['phases']) <= 1:
+ if len(saved_game_data["phases"]) <= 1:
last_phase_data = {}
else:
- last_phase_data = saved_game_data['phases'][-2]
-
+ last_phase_data = saved_game_data["phases"][-2]
+
# -------------------- Rebuild agents -------------------- #
agents = {}
@@ -269,18 +250,11 @@ def load_game_state(
elif len(provided) == 1:
power_model_map = dict(zip(powers_order, provided * len(powers_order)))
else:
- raise ValueError(
- f"Invalid --models argument: expected 1 or {len(powers_order)} "
- f"items, got {len(provided)}."
- )
+ raise ValueError(f"Invalid --models argument: expected 1 or {len(powers_order)} items, got {len(provided)}.")
if "state_agents" in last_phase_data:
logger.info("Rebuilding agents from saved state...")
- prompts_dir_from_config = (
- run_config.prompts_dir
- if run_config and hasattr(run_config, "prompts_dir")
- else None
- )
+ prompts_dir_from_config = run_config.prompts_dir if run_config and hasattr(run_config, "prompts_dir") else None
for power_name, agent_data in last_phase_data["state_agents"].items():
override_id = power_model_map.get(power_name)
agents[power_name] = deserialize_agent(
@@ -290,40 +264,30 @@ def load_game_state(
)
logger.info(f"Rebuilt {len(agents)} agents.")
else:
- raise ValueError(
- "Cannot resume: 'state_agents' key not found in the last phase "
- "of the save file."
- )
-
+ raise ValueError("Cannot resume: 'state_agents' key not found in the last phase of the save file.")
# Rebuild GameHistory
game_history = GameHistory()
logger.info("Rebuilding game history...")
- for phase_data in saved_game_data['phases'][:-1]:
- phase_name = phase_data['name']
+ for phase_data in saved_game_data["phases"][:-1]:
+ phase_name = phase_data["name"]
game_history.add_phase(phase_name)
# Add messages
- for msg in phase_data.get('messages', []):
- game_history.add_message(phase_name, msg['sender'], msg['recipient'], msg['message'])
+ for msg in phase_data.get("messages", []):
+ game_history.add_message(phase_name, msg["sender"], msg["recipient"], msg["message"])
# Add plans
- if 'state_history_plans' in phase_data:
- for p_name, plan in phase_data['state_history_plans'].items():
+ if "state_history_plans" in phase_data:
+ for p_name, plan in phase_data["state_history_plans"].items():
game_history.add_plan(phase_name, p_name, plan)
logger.info("Game history rebuilt.")
-
return game, agents, game_history, run_config
-async def initialize_new_game(
- args: Namespace,
- game: Game,
- game_history: GameHistory,
- llm_log_file_path: str
-) -> Dict[str, DiplomacyAgent]:
+async def initialize_new_game(args: Namespace, game: Game, game_history: GameHistory, llm_log_file_path: str) -> Dict[str, DiplomacyAgent]:
"""Initializes agents for a new game."""
powers_order = sorted(list(ALL_POWERS))
-
+
# Parse token limits
default_max_tokens = args.max_tokens
model_max_tokens = {p: default_max_tokens for p in powers_order}
@@ -339,7 +303,7 @@ async def initialize_new_game(
# Handle power model mapping
if args.models:
- provided_models = [name.strip() for name in args.models.split(",")]
+ provided_models = [name.strip() for name in args.models.split(",")]
if len(provided_models) == len(powers_order):
game.power_model_map = dict(zip(powers_order, provided_models))
elif len(provided_models) == 1:
@@ -364,17 +328,17 @@ async def initialize_new_game(
initialization_tasks.append(initialize_agent_state_ext(agent, game, game_history, llm_log_file_path, prompts_dir=args.prompts_dir))
except Exception as e:
logger.error(f"Failed to create agent or client for {power_name} with model {model_id}: {e}", exc_info=True)
-
+
logger.info(f"Running {len(initialization_tasks)} agent initializations concurrently...")
initialization_results = await asyncio.gather(*initialization_tasks, return_exceptions=True)
-
+
initialized_powers = list(agents.keys())
for i, result in enumerate(initialization_results):
- if i < len(initialized_powers):
- power_name = initialized_powers[i]
- if isinstance(result, Exception):
- logger.error(f"Failed to initialize agent state for {power_name}: {result}", exc_info=result)
- else:
- logger.info(f"Successfully initialized agent state for {power_name}.")
-
- return agents
\ No newline at end of file
+ if i < len(initialized_powers):
+ power_name = initialized_powers[i]
+ if isinstance(result, Exception):
+ logger.error(f"Failed to initialize agent state for {power_name}: {result}", exc_info=result)
+ else:
+ logger.info(f"Successfully initialized agent state for {power_name}.")
+
+ return agents
diff --git a/ai_diplomacy/initialization.py b/ai_diplomacy/initialization.py
index 5f460cb..32b6be1 100644
--- a/ai_diplomacy/initialization.py
+++ b/ai_diplomacy/initialization.py
@@ -1,12 +1,11 @@
# ai_diplomacy/initialization.py
import logging
import json
-import os
from typing import Optional
from ..config import config
# Forward declaration for type hinting, actual imports in function if complex
-if False: # TYPE_CHECKING
+if False: # TYPE_CHECKING
from diplomacy import Game
from diplomacy.models.game import GameHistory
from .agent import DiplomacyAgent
@@ -18,40 +17,40 @@ from .formatter import format_with_gemini_flash, FORMAT_INITIAL_STATE
logger = logging.getLogger(__name__)
+
async def initialize_agent_state_ext(
- agent: 'DiplomacyAgent',
- game: 'Game',
- game_history: 'GameHistory',
+ agent: "DiplomacyAgent",
+ game: "Game",
+ game_history: "GameHistory",
log_file_path: str,
prompts_dir: Optional[str] = None,
):
"""Uses the LLM to set initial goals and relationships for the agent."""
power_name = agent.power_name
- logger.info(f"[{power_name}] Initializing agent state using LLM (external function)..." )
+ logger.info(f"[{power_name}] Initializing agent state using LLM (external function)...")
current_phase = game.get_current_phase() if game else "UnknownPhase"
full_prompt = "" # Ensure full_prompt is defined in the outer scope for finally block
- response = "" # Ensure response is defined for finally block
- success_status = "Failure: Initialized" # Default status
+ response = "" # Ensure response is defined for finally block
+ success_status = "Failure: Initialized" # Default status
try:
# Load the prompt template
allowed_labels_str = ", ".join(ALLOWED_RELATIONSHIPS)
- initial_prompt_template = _load_prompt_file(get_prompt_path('initial_state_prompt.txt'), prompts_dir=prompts_dir)
+ initial_prompt_template = _load_prompt_file(get_prompt_path("initial_state_prompt.txt"), prompts_dir=prompts_dir)
if not initial_prompt_template:
logger.error(f"[{power_name}] Could not load {get_prompt_path('initial_state_prompt.txt')}. Cannot initialize.")
return
-
+
# Format the prompt with variables
- initial_prompt = initial_prompt_template.format(
- power_name=power_name,
- allowed_labels_str=allowed_labels_str
- )
+ initial_prompt = initial_prompt_template.format(power_name=power_name, allowed_labels_str=allowed_labels_str)
board_state = game.get_state() if game else {}
possible_orders = game.get_all_possible_orders() if game else {}
- logger.debug(f"[{power_name}] Preparing context for initial state. Board state type: {type(board_state)}, possible_orders type: {type(possible_orders)}, game_history type: {type(game_history)}")
+ logger.debug(
+ f"[{power_name}] Preparing context for initial state. Board state type: {type(board_state)}, possible_orders type: {type(possible_orders)}, game_history type: {type(game_history)}"
+ )
# Ensure agent.client and its methods can handle None for game/board_state/etc. if that's a possibility
# For initialization, game should always be present.
@@ -59,12 +58,12 @@ async def initialize_agent_state_ext(
context = build_context_prompt(
game=game,
- board_state=board_state,
+ board_state=board_state,
power_name=power_name,
- possible_orders=possible_orders,
- game_history=game_history,
- agent_goals=None,
- agent_relationships=None,
+ possible_orders=possible_orders,
+ game_history=game_history,
+ agent_goals=None,
+ agent_relationships=None,
agent_private_diary=formatted_diary,
prompts_dir=prompts_dir,
)
@@ -76,9 +75,9 @@ async def initialize_agent_state_ext(
log_file_path=log_file_path,
power_name=power_name,
phase=current_phase,
- response_type='initialization', # Context for run_llm_and_log internal error logging
+ response_type="initialization", # Context for run_llm_and_log internal error logging
)
- logger.debug(f"[{power_name}] LLM response for initial state: {response[:300]}...") # Log a snippet
+ logger.debug(f"[{power_name}] LLM response for initial state: {response[:300]}...") # Log a snippet
parsed_successfully = False
try:
@@ -86,11 +85,7 @@ async def initialize_agent_state_ext(
if config.USE_UNFORMATTED_PROMPTS:
# Format the natural language response into JSON
formatted_response = await format_with_gemini_flash(
- response,
- FORMAT_INITIAL_STATE,
- power_name=power_name,
- phase=current_phase,
- log_file_path=log_file_path
+ response, FORMAT_INITIAL_STATE, power_name=power_name, phase=current_phase, log_file_path=log_file_path
)
else:
# Use the raw response directly (already formatted)
@@ -101,20 +96,24 @@ async def initialize_agent_state_ext(
except json.JSONDecodeError as e:
logger.error(f"[{power_name}] All JSON extraction attempts failed: {e}. Response snippet: {response[:300]}...")
success_status = "Failure: JSONDecodeError"
- update_data = {} # Ensure update_data exists for fallback logic below
- parsed_successfully = False # Explicitly set here too
+ update_data = {} # Ensure update_data exists for fallback logic below
+ parsed_successfully = False # Explicitly set here too
# Fallback logic for goals/relationships will be handled later if update_data is empty
# Defensive check for update_data type if parsing was initially considered successful
- if parsed_successfully:
+ if parsed_successfully:
if isinstance(update_data, str):
- logger.error(f"[{power_name}] _extract_json_from_text returned a string, not a dict/list, despite not raising an exception. This indicates an unexpected parsing issue. String returned: {update_data[:300]}...")
- update_data = {} # Treat as parsing failure
+ logger.error(
+ f"[{power_name}] _extract_json_from_text returned a string, not a dict/list, despite not raising an exception. This indicates an unexpected parsing issue. String returned: {update_data[:300]}..."
+ )
+ update_data = {} # Treat as parsing failure
parsed_successfully = False
success_status = "Failure: ParsedAsStr"
- elif not isinstance(update_data, dict): # Expecting a dict from JSON object
- logger.error(f"[{power_name}] _extract_json_from_text returned a non-dict type ({type(update_data)}), expected dict. Data: {str(update_data)[:300]}")
- update_data = {} # Treat as parsing failure
+ elif not isinstance(update_data, dict): # Expecting a dict from JSON object
+ logger.error(
+ f"[{power_name}] _extract_json_from_text returned a non-dict type ({type(update_data)}), expected dict. Data: {str(update_data)[:300]}"
+ )
+ update_data = {} # Treat as parsing failure
parsed_successfully = False
success_status = "Failure: NotADict"
@@ -122,8 +121,8 @@ async def initialize_agent_state_ext(
initial_relationships_applied = False
if parsed_successfully:
- initial_goals = update_data.get('initial_goals') or update_data.get('goals')
- initial_relationships = update_data.get('initial_relationships') or update_data.get('relationships')
+ initial_goals = update_data.get("initial_goals") or update_data.get("goals")
+ initial_relationships = update_data.get("initial_relationships") or update_data.get("relationships")
if isinstance(initial_goals, list) and initial_goals:
agent.goals = initial_goals
@@ -152,30 +151,30 @@ async def initialize_agent_state_ext(
else:
logger.warning(f"[{power_name}] No valid relationships found in LLM response.")
else:
- logger.warning(f"[{power_name}] LLM did not provide valid 'initial_relationships' dict (got: {initial_relationships}).")
-
+ logger.warning(f"[{power_name}] LLM did not provide valid 'initial_relationships' dict (got: {initial_relationships}).")
+
if initial_goals_applied or initial_relationships_applied:
success_status = "Success: Applied LLM data"
- elif parsed_successfully: # Parsed but nothing useful to apply
+ elif parsed_successfully: # Parsed but nothing useful to apply
success_status = "Success: Parsed but no data applied"
# If not parsed_successfully, success_status is already "Failure: JSONDecodeError"
# Fallback if LLM data was not applied or parsing failed
if not initial_goals_applied:
- if not agent.goals: # Only set defaults if no goals were set during agent construction or by LLM
+ if not agent.goals: # Only set defaults if no goals were set during agent construction or by LLM
agent.goals = ["Survive and expand", "Form beneficial alliances", "Secure key territories"]
agent.add_journal_entry(f"[{current_phase}] Set default initial goals as LLM provided none or parse failed.")
logger.info(f"[{power_name}] Default goals set.")
-
+
if not initial_relationships_applied:
- # Check if relationships are still default-like before overriding
+ # Check if relationships are still default-like before overriding
is_default_relationships = True
- if agent.relationships: # Check if it's not empty
+ if agent.relationships: # Check if it's not empty
for p in ALL_POWERS:
if p != power_name and agent.relationships.get(p) != "Neutral":
is_default_relationships = False
break
- if is_default_relationships:
+ if is_default_relationships:
agent.relationships = {p: "Neutral" for p in ALL_POWERS if p != power_name}
agent.add_journal_entry(f"[{current_phase}] Set default neutral relationships as LLM provided none valid or parse failed.")
logger.info(f"[{power_name}] Default neutral relationships set.")
@@ -191,16 +190,16 @@ async def initialize_agent_state_ext(
agent.relationships = {p: "Neutral" for p in ALL_POWERS if p != power_name}
logger.info(f"[{power_name}] Set fallback neutral relationships after top-level error: {agent.relationships}")
finally:
- if log_file_path: # Ensure log_file_path is provided
+ if log_file_path: # Ensure log_file_path is provided
log_llm_response(
log_file_path=log_file_path,
model_name=agent.client.model_name if agent and agent.client else "UnknownModel",
power_name=power_name,
phase=current_phase,
- response_type="initial_state_setup", # Specific type for CSV logging
+ response_type="initial_state_setup", # Specific type for CSV logging
raw_input_prompt=full_prompt,
raw_response=response,
- success=success_status
+ success=success_status,
)
# Final log of state after initialization attempt
diff --git a/ai_diplomacy/narrative.py b/ai_diplomacy/narrative.py
index ef7f087..831d094 100644
--- a/ai_diplomacy/narrative.py
+++ b/ai_diplomacy/narrative.py
@@ -209,4 +209,3 @@ def _patched_generate_phase_summary(self: Game, phase_key, summary_callback=None
Game._generate_phase_summary = _patched_generate_phase_summary # type: ignore[assignment]
LOGGER.info("Game._generate_phase_summary patched with narrative generation.")
-
diff --git a/ai_diplomacy/negotiations.py b/ai_diplomacy/negotiations.py
index 4c471aa..3259337 100644
--- a/ai_diplomacy/negotiations.py
+++ b/ai_diplomacy/negotiations.py
@@ -6,8 +6,7 @@ from typing import Dict, TYPE_CHECKING
from diplomacy.engine.message import Message, GLOBAL
from .agent import DiplomacyAgent
-from .clients import load_model_client
-from .utils import gather_possible_orders, load_prompt, normalize_recipient_name
+from .utils import gather_possible_orders, normalize_recipient_name
if TYPE_CHECKING:
from .game_history import GameHistory
@@ -21,9 +20,9 @@ load_dotenv()
async def conduct_negotiations(
- game: 'Game',
+ game: "Game",
agents: Dict[str, DiplomacyAgent],
- game_history: 'GameHistory',
+ game_history: "GameHistory",
model_error_stats: Dict[str, Dict[str, int]],
log_file_path: str,
max_rounds: int = 3,
@@ -35,13 +34,9 @@ async def conduct_negotiations(
"""
logger.info("Starting negotiation phase.")
- active_powers = [
- p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()
- ]
- eliminated_powers = [
- p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()
- ]
-
+ active_powers = [p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()]
+ eliminated_powers = [p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()]
+
logger.info(f"Active powers for negotiations: {active_powers}")
if eliminated_powers:
logger.info(f"Eliminated powers (skipped): {eliminated_powers}")
@@ -51,7 +46,7 @@ async def conduct_negotiations(
# We do up to 'max_rounds' single-message turns for each power
for round_index in range(max_rounds):
logger.info(f"Negotiation Round {round_index + 1}/{max_rounds}")
-
+
# Prepare tasks for asyncio.gather
tasks = []
power_names_for_tasks = []
@@ -99,28 +94,28 @@ async def conduct_negotiations(
# Process results
for i, result in enumerate(results):
power_name = power_names_for_tasks[i]
- agent = agents[power_name] # Get agent again for journaling
- model_name = agent.client.model_name # Get model name for stats
+ agent = agents[power_name] # Get agent again for journaling
+ model_name = agent.client.model_name # Get model name for stats
if isinstance(result, Exception):
logger.error(f"Error getting conversation reply for {power_name}: {result}", exc_info=result)
# Use model_name for stats key if possible
if model_name in model_error_stats:
- model_error_stats[model_name]["conversation_errors"] += 1
- else: # Fallback to power_name if model name not tracked (shouldn't happen)
- model_error_stats.setdefault(power_name, {}).setdefault("conversation_errors", 0)
- model_error_stats[power_name]["conversation_errors"] += 1
- messages = [] # Treat as no messages on error
- elif result is None: # Handle case where client might return None on internal error
- logger.warning(f"Received None instead of messages for {power_name}.")
- messages = []
- if model_name in model_error_stats:
- model_error_stats[model_name]["conversation_errors"] += 1
- else:
- model_error_stats.setdefault(power_name, {}).setdefault("conversation_errors", 0)
- model_error_stats[power_name]["conversation_errors"] += 1
+ model_error_stats[model_name]["conversation_errors"] += 1
+ else: # Fallback to power_name if model name not tracked (shouldn't happen)
+ model_error_stats.setdefault(power_name, {}).setdefault("conversation_errors", 0)
+ model_error_stats[power_name]["conversation_errors"] += 1
+ messages = [] # Treat as no messages on error
+ elif result is None: # Handle case where client might return None on internal error
+ logger.warning(f"Received None instead of messages for {power_name}.")
+ messages = []
+ if model_name in model_error_stats:
+ model_error_stats[model_name]["conversation_errors"] += 1
+ else:
+ model_error_stats.setdefault(power_name, {}).setdefault("conversation_errors", 0)
+ model_error_stats[power_name]["conversation_errors"] += 1
else:
- messages = result # result is the list of message dicts
+ messages = result # result is the list of message dicts
logger.debug(f"Received {len(messages)} message(s) from {power_name}.")
# Process the received messages (same logic as before)
@@ -134,27 +129,27 @@ async def conduct_negotiations(
# Create an official message in the Diplomacy engine
# Determine recipient based on message type
if message.get("message_type") == "private":
- recipient = normalize_recipient_name(message.get("recipient", GLOBAL)) # Default to GLOBAL if recipient missing somehow
+ recipient = normalize_recipient_name(message.get("recipient", GLOBAL)) # Default to GLOBAL if recipient missing somehow
if recipient not in game.powers and recipient != GLOBAL:
logger.warning(f"Invalid recipient '{recipient}' in message from {power_name}. Sending globally.")
- recipient = GLOBAL # Fallback to GLOBAL if recipient power is invalid
- else: # Assume global if not private or type is missing
+ recipient = GLOBAL # Fallback to GLOBAL if recipient power is invalid
+ else: # Assume global if not private or type is missing
recipient = GLOBAL
-
+
diplo_message = Message(
phase=game.current_short_phase,
sender=power_name,
- recipient=recipient, # Use determined recipient
- message=message.get("content", ""), # Use .get for safety
- time_sent=None, # Let the engine assign time
+ recipient=recipient, # Use determined recipient
+ message=message.get("content", ""), # Use .get for safety
+ time_sent=None, # Let the engine assign time
)
game.add_message(diplo_message)
# Also add to our custom history
game_history.add_message(
game.current_short_phase,
power_name,
- recipient, # Use determined recipient here too
- message.get("content", ""), # Use .get for safety
+ recipient, # Use determined recipient here too
+ message.get("content", ""), # Use .get for safety
)
journal_recipient = f"to {recipient}" if recipient != GLOBAL else "globally"
agent.add_journal_entry(f"Sent message {journal_recipient} in {game.current_short_phase}: {message.get('content', '')[:100]}...")
diff --git a/ai_diplomacy/planning.py b/ai_diplomacy/planning.py
index 0d6b8a9..5531e5a 100644
--- a/ai_diplomacy/planning.py
+++ b/ai_diplomacy/planning.py
@@ -1,18 +1,17 @@
-from dotenv import load_dotenv
import logging
import concurrent.futures
from typing import Dict
-from .clients import load_model_client
from .game_history import GameHistory
from .agent import DiplomacyAgent
logger = logging.getLogger(__name__)
+
async def planning_phase(
- game,
- agents: Dict[str, DiplomacyAgent],
- game_history: GameHistory,
+ game,
+ agents: Dict[str, DiplomacyAgent],
+ game_history: GameHistory,
model_error_stats,
log_file_path: str,
):
@@ -20,24 +19,18 @@ async def planning_phase(
Lets each power generate a strategic plan using their DiplomacyAgent.
"""
logger.info(f"Starting planning phase for {game.current_short_phase}...")
- active_powers = [
- p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()
- ]
- eliminated_powers = [
- p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()
- ]
-
+ active_powers = [p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()]
+ eliminated_powers = [p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()]
+
logger.info(f"Active powers for planning: {active_powers}")
if eliminated_powers:
logger.info(f"Eliminated powers (skipped): {eliminated_powers}")
else:
logger.info("No eliminated powers yet.")
-
+
board_state = game.get_state()
- with concurrent.futures.ThreadPoolExecutor(
- max_workers=len(active_powers)
- ) as executor:
+ with concurrent.futures.ThreadPoolExecutor(max_workers=len(active_powers)) as executor:
futures = {}
for power_name in active_powers:
if power_name not in agents:
@@ -45,7 +38,7 @@ async def planning_phase(
continue
agent = agents[power_name]
client = agent.client
-
+
future = executor.submit(
client.get_plan,
game,
@@ -67,20 +60,18 @@ async def planning_phase(
try:
plan_result = future.result()
logger.info(f"Received planning result from {power_name}.")
-
+
if plan_result.startswith("Error:"):
- logger.warning(f"Agent {power_name} reported an error during planning: {plan_result}")
- if power_name in model_error_stats:
- model_error_stats[power_name].setdefault('planning_generation_errors', 0)
- model_error_stats[power_name]['planning_generation_errors'] += 1
- else:
- model_error_stats.setdefault(f'{power_name}_planning_generation_errors', 0)
- model_error_stats[f'{power_name}_planning_generation_errors'] += 1
+ logger.warning(f"Agent {power_name} reported an error during planning: {plan_result}")
+ if power_name in model_error_stats:
+ model_error_stats[power_name].setdefault("planning_generation_errors", 0)
+ model_error_stats[power_name]["planning_generation_errors"] += 1
+ else:
+ model_error_stats.setdefault(f"{power_name}_planning_generation_errors", 0)
+ model_error_stats[f"{power_name}_planning_generation_errors"] += 1
elif plan_result:
agent.add_journal_entry(f"Generated plan for {game.current_short_phase}: {plan_result[:100]}...")
- game_history.add_plan(
- game.current_short_phase, power_name, plan_result
- )
+ game_history.add_plan(game.current_short_phase, power_name, plan_result)
logger.debug(f"Added plan for {power_name} to history.")
else:
logger.warning(f"Agent {power_name} returned an empty plan.")
@@ -88,11 +79,11 @@ async def planning_phase(
except Exception as e:
logger.error(f"Exception during planning result processing for {power_name}: {e}")
if power_name in model_error_stats:
- model_error_stats[power_name].setdefault('planning_execution_errors', 0)
- model_error_stats[power_name]['planning_execution_errors'] += 1
+ model_error_stats[power_name].setdefault("planning_execution_errors", 0)
+ model_error_stats[power_name]["planning_execution_errors"] += 1
else:
- model_error_stats.setdefault(f'{power_name}_planning_execution_errors', 0)
- model_error_stats[f'{power_name}_planning_execution_errors'] += 1
-
+ model_error_stats.setdefault(f"{power_name}_planning_execution_errors", 0)
+ model_error_stats[f"{power_name}_planning_execution_errors"] += 1
+
logger.info("Planning phase processing complete.")
- return game_history
\ No newline at end of file
+ return game_history
diff --git a/ai_diplomacy/possible_order_context.py b/ai_diplomacy/possible_order_context.py
index 39c0bb9..1c0b47e 100644
--- a/ai_diplomacy/possible_order_context.py
+++ b/ai_diplomacy/possible_order_context.py
@@ -8,12 +8,13 @@ import logging
import re
# Placeholder for actual map type from diplomacy.engine.map.Map
-# GameMap = Any
+# GameMap = Any
# Type hint for board_state dictionary from game.get_state()
# BoardState = Dict[str, Any]
logger = logging.getLogger(__name__)
+
def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]:
"""
Return graph[PROV]['ARMY'|'FLEET'] = list of 3-letter neighbour provinces.
@@ -22,20 +23,18 @@ def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]:
"""
# ── collect all 3-letter province codes ───────────────────────────────
provs: Set[str] = {
- loc.split("/")[0][:3].upper() # 'BUL/EC' -> 'BUL'
+ loc.split("/")[0][:3].upper() # 'BUL/EC' -> 'BUL'
for loc in game_map.locs
if len(loc.split("/")[0]) == 3
}
- graph: Dict[str, Dict[str, List[str]]] = {
- p: {"ARMY": [], "FLEET": []} for p in provs
- }
+ graph: Dict[str, Dict[str, List[str]]] = {p: {"ARMY": [], "FLEET": []} for p in provs}
# ── helper: list every concrete variant of a province ─────────────────
def variants(code: str) -> List[str]:
lst = list(game_map.loc_coasts.get(code, []))
if code not in lst:
- lst.append(code) # ensure base node included
+ lst.append(code) # ensure base node included
return lst
# ── populate adjacency by brute-force queries to `abuts()` ────────────
@@ -55,48 +54,43 @@ def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]:
graph[src]["ARMY"].append(dest)
# FLEET — any src variant that can host a fleet is valid
- if any(
- game_map.abuts("F", sv, "-", dv)
- for sv in src_vers
- for dv in dest_vers
- ):
+ if any(game_map.abuts("F", sv, "-", dv) for sv in src_vers for dv in dest_vers):
graph[src]["FLEET"].append(dest)
# ── tidy up duplicates / order ---------------------------------------
for p in graph:
- graph[p]["ARMY"] = sorted(set(graph[p]["ARMY"]))
+ graph[p]["ARMY"] = sorted(set(graph[p]["ARMY"]))
graph[p]["FLEET"] = sorted(set(graph[p]["FLEET"]))
return graph
-
def bfs_shortest_path(
- graph: Dict[str, Dict[str, List[str]]],
- board_state: BoardState,
- game_map: GameMap, # Added game_map
- start_loc_full: str, # This is a FULL location name like 'VIE' or 'STP/SC'
- unit_type: str,
- is_target_func: Callable[[str, BoardState], bool] # Expects SHORT name for loc
-) -> Optional[List[str]]: # Returns path of SHORT names
+ graph: Dict[str, Dict[str, List[str]]],
+ board_state: BoardState,
+ game_map: GameMap, # Added game_map
+ start_loc_full: str, # This is a FULL location name like 'VIE' or 'STP/SC'
+ unit_type: str,
+ is_target_func: Callable[[str, BoardState], bool], # Expects SHORT name for loc
+) -> Optional[List[str]]: # Returns path of SHORT names
"""Performs BFS to find the shortest path from start_loc to a target satisfying is_target_func."""
-
+
# Convert full start location to short province name
start_loc_short = game_map.loc_name.get(start_loc_full, start_loc_full)
- if '/' in start_loc_short: # If it was STP/SC, loc_name gives STP. If it was VIE, loc_name gives VIE.
+ if "/" in start_loc_short: # If it was STP/SC, loc_name gives STP. If it was VIE, loc_name gives VIE.
start_loc_short = start_loc_short[:3]
# If start_loc_full was already short (e.g. 'VIE'), get might return it as is, or its value if it was a key.
# A simpler way for non-coastal full (like 'VIE') or already short:
- if '/' not in start_loc_full:
- start_loc_short = start_loc_full[:3] # Ensures 'VIE' -> 'VIE', 'PAR' -> 'PAR'
- else: # Has '/', e.g. 'STP/SC'
- start_loc_short = start_loc_full[:3] # 'STP/SC' -> 'STP'
+ if "/" not in start_loc_full:
+ start_loc_short = start_loc_full[:3] # Ensures 'VIE' -> 'VIE', 'PAR' -> 'PAR'
+ else: # Has '/', e.g. 'STP/SC'
+ start_loc_short = start_loc_full[:3] # 'STP/SC' -> 'STP'
if start_loc_short not in graph:
logger.warning(f"BFS: Start province {start_loc_short} (from {start_loc_full}) not in graph. Pathfinding may fail.")
return None
- queue: deque[Tuple[str, List[str]]] = deque([(start_loc_short, [start_loc_short])])
+ queue: deque[Tuple[str, List[str]]] = deque([(start_loc_short, [start_loc_short])])
visited_nodes: Set[str] = {start_loc_short}
while queue:
@@ -104,14 +98,14 @@ def bfs_shortest_path(
# is_target_func expects a short location name
if is_target_func(current_loc_short, board_state):
- return path # Path of short names
+ return path # Path of short names
# possible_neighbors are SHORT names from the graph
possible_neighbors_short = graph.get(current_loc_short, {}).get(unit_type, [])
-
+
for next_loc_short in possible_neighbors_short:
if next_loc_short not in visited_nodes:
- if next_loc_short not in graph: # Defensive check for neighbors not in graph keys
+ if next_loc_short not in graph: # Defensive check for neighbors not in graph keys
logger.warning(f"BFS: Neighbor {next_loc_short} of {current_loc_short} not in graph. Skipping.")
continue
visited_nodes.add(next_loc_short)
@@ -119,11 +113,12 @@ def bfs_shortest_path(
queue.append((next_loc_short, new_path))
return None
+
# --- Helper functions for context generation ---
def get_unit_at_location(board_state: BoardState, location: str) -> Optional[str]:
"""Returns the full unit string (e.g., 'A PAR (FRA)') if a unit is at the location, else None."""
- for power, unit_list in board_state.get('units', {}).items():
- for unit_str in unit_list: # e.g., "A PAR", "F STP/SC"
+ for power, unit_list in board_state.get("units", {}).items():
+ for unit_str in unit_list: # e.g., "A PAR", "F STP/SC"
parts = unit_str.split(" ")
if len(parts) == 2:
unit_map_loc = parts[1]
@@ -131,27 +126,29 @@ def get_unit_at_location(board_state: BoardState, location: str) -> Optional[str
return f"{parts[0]} {location} ({power})"
return None
+
def get_sc_controller(game_map: GameMap, board_state: BoardState, location: str) -> Optional[str]:
"""Returns the controlling power's name if the location is an SC, else None."""
# Normalize location to base province name, as SCs are tied to provinces, not specific coasts
loc_province_name = game_map.loc_name.get(location, location).upper()[:3]
if loc_province_name not in game_map.scs:
return None
- for power, sc_list in board_state.get('centers', {}).items():
+ for power, sc_list in board_state.get("centers", {}).items():
if loc_province_name in sc_list:
return power
- return None # Unowned SC
+ return None # Unowned SC
+
def get_shortest_path_to_friendly_unit(
- board_state: BoardState,
+ board_state: BoardState,
graph: Dict[str, Dict[str, List[str]]],
- game_map: GameMap, # Added game_map
- power_name: str,
- start_unit_loc_full: str,
- start_unit_type: str
+ game_map: GameMap, # Added game_map
+ power_name: str,
+ start_unit_loc_full: str,
+ start_unit_type: str,
) -> Optional[Tuple[str, List[str]]]:
"""Finds the shortest path to any friendly unit of the same power."""
-
+
def is_target_friendly(loc_short: str, current_board_state: BoardState) -> bool:
# loc_short is a short province name. Need to check all its full locations.
full_locs_for_short = game_map.loc_coasts.get(loc_short, [loc_short])
@@ -162,7 +159,7 @@ def get_shortest_path_to_friendly_unit(
return False
path_short_names = bfs_shortest_path(graph, board_state, game_map, start_unit_loc_full, start_unit_type, is_target_friendly)
- if path_short_names and len(path_short_names) > 1: # Path includes start, so > 1 means a distinct friendly unit found
+ if path_short_names and len(path_short_names) > 1: # Path includes start, so > 1 means a distinct friendly unit found
target_loc_short = path_short_names[-1]
# Find the actual friendly unit string at one of the full locations of target_loc_short
friendly_unit_str = "UNKNOWN_FRIENDLY_UNIT"
@@ -177,41 +174,41 @@ def get_shortest_path_to_friendly_unit(
def get_nearest_enemy_units(
- board_state: BoardState,
+ board_state: BoardState,
graph: Dict[str, Dict[str, List[str]]],
- game_map: GameMap, # Added game_map
- power_name: str,
- start_unit_loc_full: str,
- start_unit_type: str,
- n: int = 3
+ game_map: GameMap, # Added game_map
+ power_name: str,
+ start_unit_loc_full: str,
+ start_unit_type: str,
+ n: int = 3,
) -> List[Tuple[str, List[str]]]:
"""Finds up to N nearest enemy units, sorted by path length."""
- enemy_paths: List[Tuple[str, List[str]]] = [] # (enemy_unit_str, path_short_names)
-
- all_enemy_unit_locations_full: List[Tuple[str,str]] = [] # (loc_full, unit_str_full)
+ enemy_paths: List[Tuple[str, List[str]]] = [] # (enemy_unit_str, path_short_names)
+
+ all_enemy_unit_locations_full: List[Tuple[str, str]] = [] # (loc_full, unit_str_full)
# board_state.get("units", {}) has format: { "POWER_NAME": ["A PAR", "F BRE"], ... }
for p_name, unit_list_for_power in board_state.get("units", {}).items():
- if p_name != power_name: # If it's an enemy power
- for unit_repr_from_state in unit_list_for_power: # e.g., "A PAR" or "F STP/SC"
+ if p_name != power_name: # If it's an enemy power
+ for unit_repr_from_state in unit_list_for_power: # e.g., "A PAR" or "F STP/SC"
parts = unit_repr_from_state.split(" ")
if len(parts) == 2:
# unit_type_char = parts[0] # 'A' or 'F'
- loc_full = parts[1] # 'PAR' or 'STP/SC'
-
+ loc_full = parts[1] # 'PAR' or 'STP/SC'
+
# Use get_unit_at_location to get the consistent full unit string like "A PAR (POWER_NAME)"
full_unit_str_with_power = get_unit_at_location(board_state, loc_full)
- if full_unit_str_with_power: # Should find the unit if iteration is correct
- all_enemy_unit_locations_full.append((loc_full, full_unit_str_with_power))
+ if full_unit_str_with_power: # Should find the unit if iteration is correct
+ all_enemy_unit_locations_full.append((loc_full, full_unit_str_with_power))
for target_enemy_loc_full, enemy_unit_str in all_enemy_unit_locations_full:
target_enemy_loc_short = game_map.loc_name.get(target_enemy_loc_full, target_enemy_loc_full)
- if '/' in target_enemy_loc_short:
+ if "/" in target_enemy_loc_short:
target_enemy_loc_short = target_enemy_loc_short[:3]
- if '/' not in target_enemy_loc_full:
+ if "/" not in target_enemy_loc_full:
target_enemy_loc_short = target_enemy_loc_full[:3]
else:
target_enemy_loc_short = target_enemy_loc_full[:3]
-
+
def is_specific_enemy_loc(loc_short: str, current_board_state: BoardState) -> bool:
# Check if loc_short corresponds to target_enemy_loc_full
return loc_short == target_enemy_loc_short
@@ -219,8 +216,8 @@ def get_nearest_enemy_units(
path_short_names = bfs_shortest_path(graph, board_state, game_map, start_unit_loc_full, start_unit_type, is_specific_enemy_loc)
if path_short_names:
enemy_paths.append((enemy_unit_str, path_short_names))
-
- enemy_paths.sort(key=lambda x: len(x[1])) # Sort by path length
+
+ enemy_paths.sort(key=lambda x: len(x[1])) # Sort by path length
return enemy_paths[:n]
@@ -242,10 +239,10 @@ def get_nearest_uncontrolled_scs(
"""
results: List[Tuple[str, int, List[str]]] = []
- for sc_short in game_map.scs: # all SC province codes
+ for sc_short in game_map.scs: # all SC province codes
controller = get_sc_controller(game_map, board_state, sc_short)
if controller == power_name:
- continue # already ours
+ continue # already ours
# helper for BFS target test
def is_target(loc_short: str, _state: BoardState) -> bool:
@@ -260,9 +257,9 @@ def get_nearest_uncontrolled_scs(
is_target,
)
if not path:
- continue # unreachable
+ continue # unreachable
- distance = len(path) - 1 # moves needed
+ distance = len(path) - 1 # moves needed
# skip distance 0 (same province) and 1 (adjacent)
if distance <= 1:
@@ -277,11 +274,11 @@ def get_nearest_uncontrolled_scs(
def get_adjacent_territory_details(
- game_map: GameMap,
- board_state: BoardState,
- unit_loc_full: str, # The location of the unit whose adjacencies we're checking
- unit_type: str, # ARMY or FLEET of the unit at unit_loc_full
- graph: Dict[str, Dict[str, List[str]]]
+ game_map: GameMap,
+ board_state: BoardState,
+ unit_loc_full: str, # The location of the unit whose adjacencies we're checking
+ unit_type: str, # ARMY or FLEET of the unit at unit_loc_full
+ graph: Dict[str, Dict[str, List[str]]],
) -> str:
"""Generates a string describing adjacent territories and units that can interact with them."""
output_lines: List[str] = []
@@ -291,37 +288,37 @@ def get_adjacent_territory_details(
# For fleets, graph[unit_loc_full]['FLEET'] gives full loc names (incl coasts)
# THIS COMMENT IS NOW OUTDATED. Graph uses short names for keys and values.
unit_loc_short = game_map.loc_name.get(unit_loc_full, unit_loc_full)
- if '/' in unit_loc_short:
+ if "/" in unit_loc_short:
unit_loc_short = unit_loc_short[:3]
- if '/' not in unit_loc_full:
+ if "/" not in unit_loc_full:
unit_loc_short = unit_loc_full[:3]
else:
unit_loc_short = unit_loc_full[:3]
- adjacent_locs_short_for_unit = graph.get(unit_loc_short, {}).get(unit_type, [])
+ adjacent_locs_short_for_unit = graph.get(unit_loc_short, {}).get(unit_type, [])
- processed_adj_provinces = set() # To handle cases like STP/NC and STP/SC both being adjacent to BOT
+ processed_adj_provinces = set() # To handle cases like STP/NC and STP/SC both being adjacent to BOT
- for adj_loc_short in adjacent_locs_short_for_unit: # adj_loc_short is already short
+ for adj_loc_short in adjacent_locs_short_for_unit: # adj_loc_short is already short
# adj_province_short = game_map.loc_name.get(adj_loc_full, adj_loc_full).upper()[:3] # No longer needed
- if adj_loc_short in processed_adj_provinces: # adj_loc_short is already short and upper implicitly by map data
+ if adj_loc_short in processed_adj_provinces: # adj_loc_short is already short and upper implicitly by map data
continue
processed_adj_provinces.add(adj_loc_short)
- adj_loc_type = game_map.loc_type.get(adj_loc_short, 'UNKNOWN').upper()
- if adj_loc_type == 'COAST' or adj_loc_type == 'LAND':
- adj_loc_type_display = 'LAND' if adj_loc_type == 'LAND' else 'COAST'
- elif adj_loc_type == 'WATER':
- adj_loc_type_display = 'WATER'
- else: # SHUT etc.
+ adj_loc_type = game_map.loc_type.get(adj_loc_short, "UNKNOWN").upper()
+ if adj_loc_type == "COAST" or adj_loc_type == "LAND":
+ adj_loc_type_display = "LAND" if adj_loc_type == "LAND" else "COAST"
+ elif adj_loc_type == "WATER":
+ adj_loc_type_display = "WATER"
+ else: # SHUT etc.
adj_loc_type_display = adj_loc_type
-
+
line = f" {adj_loc_short} ({adj_loc_type_display})"
-
+
sc_controller = get_sc_controller(game_map, board_state, adj_loc_short)
if sc_controller:
line += f" SC Control: {sc_controller}"
-
+
unit_in_adj_loc = get_unit_at_location(board_state, adj_loc_short)
if unit_in_adj_loc:
line += f" Units: {unit_in_adj_loc}"
@@ -330,15 +327,14 @@ def get_adjacent_territory_details(
# "Can support/move to" - Simplified: list units in *further* adjacent provinces
# A true "can support/move to" would require checking possible orders of those further units.
# further_adj_provinces are short names from the graph
- further_adj_provinces_short = graph.get(adj_loc_short, {}).get('ARMY', []) + \
- graph.get(adj_loc_short, {}).get('FLEET', [])
-
+ further_adj_provinces_short = graph.get(adj_loc_short, {}).get("ARMY", []) + graph.get(adj_loc_short, {}).get("FLEET", [])
+
supporting_units_info = []
processed_further_provinces = set()
for further_adj_loc_short in further_adj_provinces_short:
# further_adj_province_short = game_map.loc_name.get(further_adj_loc_full, further_adj_loc_full).upper()[:3]
# No conversion needed, it's already short
- if further_adj_loc_short == adj_loc_short or further_adj_loc_short == unit_loc_short: # Don't list itself or origin
+ if further_adj_loc_short == adj_loc_short or further_adj_loc_short == unit_loc_short: # Don't list itself or origin
continue
if further_adj_loc_short in processed_further_provinces:
continue
@@ -354,14 +350,14 @@ def get_adjacent_territory_details(
temp_unit = get_unit_at_location(board_state, fv_further)
if temp_unit:
unit_in_further_loc = temp_unit
- break # Found a unit in one of the coasts/base
-
+ break # Found a unit in one of the coasts/base
+
# if not unit_in_further_loc and further_adj_loc_full != further_adj_province_short:
# unit_in_further_loc = get_unit_at_location(board_state, further_adj_province_short)
-
+
if unit_in_further_loc:
supporting_units_info.append(unit_in_further_loc)
-
+
if supporting_units_info:
output_lines.append(f" => Can support/move to: {', '.join(sorted(list(set(supporting_units_info))))}")
@@ -378,30 +374,30 @@ def generate_rich_order_context_xml(game: Any, power_name: str, possible_orders_
board_state: BoardState = game.get_state()
game_map: GameMap = game.map
graph = build_diplomacy_graph(game_map)
-
+
final_context_lines: List[str] = [""]
# Iterate through units that have orders (keys of possible_orders_for_power are unit locations)
for unit_loc_full, unit_specific_possible_orders in possible_orders_for_power.items():
unit_str_full = get_unit_at_location(board_state, unit_loc_full)
- if not unit_str_full: # Should not happen if unit_loc_full is from possible_orders keys
- continue
+ if not unit_str_full: # Should not happen if unit_loc_full is from possible_orders keys
+ continue
- unit_type_char = unit_str_full.split(" ")[0] # 'A' or 'F'
- unit_type_long = "ARMY" if unit_type_char == 'A' else "FLEET"
+ unit_type_char = unit_str_full.split(" ")[0] # 'A' or 'F'
+ unit_type_long = "ARMY" if unit_type_char == "A" else "FLEET"
loc_province_short = game_map.loc_name.get(unit_loc_full, unit_loc_full).upper()[:3]
loc_type_short = game_map.loc_type.get(loc_province_short, "UNKNOWN").upper()
- if loc_type_short == 'COAST' or loc_type_short == 'LAND':
- loc_type_display = 'LAND' if loc_type_short == 'LAND' else 'COAST'
+ if loc_type_short == "COAST" or loc_type_short == "LAND":
+ loc_type_display = "LAND" if loc_type_short == "LAND" else "COAST"
else:
loc_type_display = loc_type_short
current_unit_lines: List[str] = []
current_unit_lines.append(f' ')
-
+
# Unit Information section
- current_unit_lines.append(' ')
+ current_unit_lines.append(" ")
sc_owner_at_loc = get_sc_controller(game_map, board_state, unit_loc_full)
header_content = f"Strategic territory held by {power_name}: {unit_loc_full} ({loc_type_display})"
if sc_owner_at_loc == power_name:
@@ -410,55 +406,61 @@ def generate_rich_order_context_xml(game: Any, power_name: str, possible_orders_
header_content += f" (SC controlled by {sc_owner_at_loc})"
current_unit_lines.append(f" {header_content}")
current_unit_lines.append(f" Units present: {unit_str_full}")
- current_unit_lines.append(' ')
+ current_unit_lines.append(" ")
# Possible moves section
- current_unit_lines.append(' ')
+ current_unit_lines.append(" ")
current_unit_lines.append(" Possible moves:")
for order_str in unit_specific_possible_orders:
current_unit_lines.append(f" {order_str}")
- current_unit_lines.append(' ')
-
+ current_unit_lines.append(" ")
+
# Nearest enemy units section
enemy_units_info = get_nearest_enemy_units(board_state, graph, game_map, power_name, unit_loc_full, unit_type_long, n=3)
- current_unit_lines.append(' ')
+ current_unit_lines.append(" ")
if enemy_units_info:
current_unit_lines.append(" Nearest units (not ours):")
for enemy_unit_str, enemy_path_short in enemy_units_info:
- current_unit_lines.append(f" {enemy_unit_str}, path=[{unit_loc_full}→{('→'.join(enemy_path_short[1:])) if len(enemy_path_short) > 1 else enemy_path_short[0]}]")
+ current_unit_lines.append(
+ f" {enemy_unit_str}, path=[{unit_loc_full}→{('→'.join(enemy_path_short[1:])) if len(enemy_path_short) > 1 else enemy_path_short[0]}]"
+ )
else:
current_unit_lines.append(" Nearest units (not ours): None found")
- current_unit_lines.append(' ')
+ current_unit_lines.append(" ")
# Nearest supply centers (not controlled by us) section
uncontrolled_scs_info = get_nearest_uncontrolled_scs(game_map, board_state, graph, power_name, unit_loc_full, unit_type_long, n=3)
- current_unit_lines.append(' ')
+ current_unit_lines.append(" ")
if uncontrolled_scs_info:
current_unit_lines.append(" Nearest supply centers (not controlled by us):")
for sc_str, dist, sc_path_short in uncontrolled_scs_info:
- current_unit_lines.append(f" {sc_str}, dist={dist}, path=[{unit_loc_full}→{('→'.join(sc_path_short[1:])) if len(sc_path_short) > 1 else sc_path_short[0]}]")
+ current_unit_lines.append(
+ f" {sc_str}, dist={dist}, path=[{unit_loc_full}→{('→'.join(sc_path_short[1:])) if len(sc_path_short) > 1 else sc_path_short[0]}]"
+ )
else:
current_unit_lines.append(" Nearest supply centers (not controlled by us): None found")
- current_unit_lines.append(' ')
+ current_unit_lines.append(" ")
# Adjacent territories details section
adj_details_str = get_adjacent_territory_details(game_map, board_state, unit_loc_full, unit_type_long, graph)
- current_unit_lines.append(' ')
+ current_unit_lines.append(" ")
if adj_details_str:
current_unit_lines.append(" Adjacent territories (including units that can support/move to the adjacent territory):")
# Assuming adj_details_str is already formatted with newlines and indentation for its content
# We might need to indent adj_details_str if it's a single block of text
# For now, let's add a standard indent to each line of adj_details_str if it contains newlines
- if '\n' in adj_details_str:
- indented_adj_details = "\n".join([f" {line}" for line in adj_details_str.split('\n')])
+ if "\n" in adj_details_str:
+ indented_adj_details = "\n".join([f" {line}" for line in adj_details_str.split("\n")])
current_unit_lines.append(indented_adj_details)
else:
- current_unit_lines.append(f" {adj_details_str}")
+ current_unit_lines.append(f" {adj_details_str}")
else:
- current_unit_lines.append(" Adjacent territories: None relevant or all are empty/uncontested by direct threats.") # Added more descriptive else
- current_unit_lines.append(' ')
-
- current_unit_lines.append(' ')
+ current_unit_lines.append(
+ " Adjacent territories: None relevant or all are empty/uncontested by direct threats."
+ ) # Added more descriptive else
+ current_unit_lines.append(" ")
+
+ current_unit_lines.append(" ")
final_context_lines.extend(current_unit_lines)
final_context_lines.append("")
@@ -469,16 +471,16 @@ def generate_rich_order_context_xml(game: Any, power_name: str, possible_orders_
# Regex and tiny helpers
# ---------------------------------------------------------------------------
-import re
from typing import Tuple, List, Dict, Optional, Any
# ── order-syntax matchers ─────────────────────────────────────────────────
_SIMPLE_MOVE_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? - [A-Z]{3}(?:/[A-Z]{2})?$")
-_HOLD_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? H$") # NEW
-_RETREAT_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? R [A-Z]{3}(?:/[A-Z]{2})?$")
-_ADJUST_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? [BD]$") # build / disband
+_HOLD_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? H$") # NEW
+_RETREAT_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? R [A-Z]{3}(?:/[A-Z]{2})?$")
+_ADJUST_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? [BD]$") # build / disband
-def _is_hold_order(order: str) -> bool: # NEW
+
+def _is_hold_order(order: str) -> bool: # NEW
return bool(_HOLD_RE.match(order.strip()))
@@ -486,15 +488,19 @@ def _norm_power(name: str) -> str:
"""Trim & uppercase for reliable comparisons."""
return name.strip().upper()
+
def _is_simple_move(order: str) -> bool:
return bool(_SIMPLE_MOVE_RE.match(order.strip()))
+
def _is_retreat_order(order: str) -> bool:
return bool(_RETREAT_RE.match(order.strip()))
+
def _is_adjust_order(order: str) -> bool:
return bool(_ADJUST_RE.match(order.strip()))
+
def _split_move(order: str) -> Tuple[str, str]:
"""Return ('A BUD', 'TRI') from 'A BUD - TRI' (validated move only)."""
unit_part, dest = order.split(" - ")
@@ -505,6 +511,7 @@ def _split_move(order: str) -> Tuple[str, str]:
# Gather *all* friendly support orders for a given move
# ---------------------------------------------------------------------------
+
def _all_support_examples(
mover: str,
dest: str,
@@ -526,6 +533,7 @@ def _all_support_examples(
return supports
+
def _all_support_hold_examples(
holder: str,
all_orders: Dict[str, List[str]],
@@ -538,7 +546,7 @@ def _all_support_hold_examples(
supports: List[str] = []
for loc, orders in all_orders.items():
- if holder.endswith(loc): # skip the holding unit
+ if holder.endswith(loc): # skip the holding unit
continue
for o in orders:
if o.strip().endswith(target):
@@ -550,6 +558,7 @@ def _all_support_hold_examples(
# Province-type resolver (handles short codes, coasts, seas)
# ---------------------------------------------------------------------------
+
def _province_type_display(game_map, prov_short: str) -> str:
"""
Return 'LAND', 'COAST', or 'WATER' for the 3-letter province code.
@@ -575,7 +584,7 @@ def _dest_occupancy_desc(
board_state,
our_power: str,
) -> str:
- """ '(occupied by X)', '(occupied by X — you!)', or '(unoccupied)' """
+ """'(occupied by X)', '(occupied by X — you!)', or '(unoccupied)'"""
occupant: Optional[str] = None
for full in game_map.loc_coasts.get(dest_short, [dest_short]):
u = get_unit_at_location(board_state, full)
@@ -593,6 +602,7 @@ def _dest_occupancy_desc(
# Adjacent-territory lines (used by movement-phase builder)
# ---------------------------------------------------------------------------
+
def _adjacent_territory_lines(
graph,
game_map,
@@ -633,13 +643,9 @@ def _adjacent_territory_lines(
pwr = unit_here.split(" ")[-1].strip("()")
if pwr == our_power:
friend_descr = unit_here.split(" (")[0]
- lines.append(
- f"{indent2}Support hold: {mover_descr} S {friend_descr}"
- )
+ lines.append(f"{indent2}Support hold: {mover_descr} S {friend_descr}")
else:
- lines.append(
- f"{indent2}-> {unit_here} can support or contest {mover_descr}’s moves and vice-versa"
- )
+ lines.append(f"{indent2}-> {unit_here} can support or contest {mover_descr}’s moves and vice-versa")
return lines
@@ -648,6 +654,7 @@ def _adjacent_territory_lines(
# Movement-phase generator (UNCHANGED LOGIC)
# ---------------------------------------------------------------------------
+
def _generate_rich_order_context_movement(
game: Any,
power_name: str,
@@ -672,18 +679,14 @@ def _generate_rich_order_context_movement(
if _norm_power(unit_power) != me:
continue # Skip units that aren’t ours
- mover_descr, _ = _split_move(
- f"{unit_full_str.split(' ')[0]} {unit_loc_full} - {unit_loc_full}"
- )
+ mover_descr, _ = _split_move(f"{unit_full_str.split(' ')[0]} {unit_loc_full} - {unit_loc_full}")
prov_short = game_map.loc_name.get(unit_loc_full, unit_loc_full)[:3]
prov_type_disp = _province_type_display(game_map, prov_short)
sc_tag = " (SC)" if prov_short in game_map.scs else ""
owner = get_sc_controller(game_map, board_state, unit_loc_full) or "None"
- owner_line = (
- f"Held by {owner} (You)" if owner == power_name else f"Held by {owner}"
- )
+ owner_line = f"Held by {owner} (You)" if owner == power_name else f"Held by {owner}"
ind = " "
block: List[str] = [f""]
@@ -693,18 +696,16 @@ def _generate_rich_order_context_movement(
# ----- Adjacent territories -----
block.append("# Adjacent territories:")
- block.extend(
- _adjacent_territory_lines(
- graph, game_map, board_state,
- unit_loc_full, mover_descr, power_name
- )
- )
+ block.extend(_adjacent_territory_lines(graph, game_map, board_state, unit_loc_full, mover_descr, power_name))
# ----- Nearest enemy units -----
block.append("# Nearest units (not ours):")
enemies = get_nearest_enemy_units(
- board_state, graph, game_map,
- power_name, unit_loc_full,
+ board_state,
+ graph,
+ game_map,
+ power_name,
+ unit_loc_full,
"ARMY" if mover_descr.startswith("A") else "FLEET",
n=3,
)
@@ -715,8 +716,11 @@ def _generate_rich_order_context_movement(
# ----- Nearest uncontrolled SCs -----
block.append("# Nearest supply centers (not controlled by us):")
scs = get_nearest_uncontrolled_scs(
- game_map, board_state, graph,
- power_name, unit_loc_full,
+ game_map,
+ board_state,
+ graph,
+ power_name,
+ unit_loc_full,
"ARMY" if mover_descr.startswith("A") else "FLEET",
n=3,
)
@@ -729,7 +733,7 @@ def _generate_rich_order_context_movement(
block.append(f"# Possible {mover_descr} unit movements & supports:")
simple_moves = [o for o in orders if _is_simple_move(o)]
- hold_orders = [o for o in orders if _is_hold_order(o)] # NEW
+ hold_orders = [o for o in orders if _is_hold_order(o)] # NEW
if not simple_moves and not hold_orders:
block.append(f"{ind}None")
@@ -737,22 +741,19 @@ def _generate_rich_order_context_movement(
# ---- Moves (same behaviour as before) ----
for mv in simple_moves:
mover, dest = _split_move(mv)
- occ = _dest_occupancy_desc(
- dest.split("/")[0][:3], game_map, board_state, power_name
- )
+ occ = _dest_occupancy_desc(dest.split("/")[0][:3], game_map, board_state, power_name)
block.append(f"{ind}{mv} {occ}")
for s in _all_support_examples(mover, dest, possible_orders_for_power):
- block.append(f"{ind*2}Available Support: {s}")
+ block.append(f"{ind * 2}Available Support: {s}")
# ---- Holds (new) ----
for hd in hold_orders:
- holder = hd.split(" H")[0] # e.g., 'F DEN'
+ holder = hd.split(" H")[0] # e.g., 'F DEN'
block.append(f"{ind}{hd}")
for s in _all_support_hold_examples(holder, possible_orders_for_power):
- block.append(f"{ind*2}Available Support: {s}")
-
+ block.append(f"{ind * 2}Available Support: {s}")
block.append(f"")
blocks.append("\n".join(block))
@@ -764,6 +765,7 @@ def _generate_rich_order_context_movement(
# Retreat-phase builder – echo orders verbatim, no tags
# ---------------------------------------------------------------------------
+
def _generate_rich_order_context_retreat(
game: Any,
power_name: str,
@@ -784,11 +786,11 @@ def _generate_rich_order_context_retreat(
return "\n".join(lines) if lines else "(No dislodged units)"
-
# ---------------------------------------------------------------------------
# Adjustment-phase builder – summary line + orders, no WAIVEs, no tags
# ---------------------------------------------------------------------------
+
def _generate_rich_order_context_adjustment(
game: Any,
power_name: str,
@@ -800,12 +802,12 @@ def _generate_rich_order_context_adjustment(
* No wrapper tags.
"""
board_state = game.get_state()
- sc_owned = len(board_state.get("centers", {}).get(power_name, []))
- units_num = len(board_state.get("units", {}).get(power_name, []))
- delta = sc_owned - units_num # +ve ⇒ builds, -ve ⇒ disbands
+ sc_owned = len(board_state.get("centers", {}).get(power_name, []))
+ units_num = len(board_state.get("units", {}).get(power_name, []))
+ delta = sc_owned - units_num # +ve ⇒ builds, -ve ⇒ disbands
# ----- summary line ----------------------------------------------------
- if delta > 0:
+ if delta > 0:
summary = f"Builds available: {delta}"
elif delta < 0:
summary = f"Disbands required: {-delta}"
@@ -824,13 +826,11 @@ def _generate_rich_order_context_adjustment(
return "\n".join(lines) if len(lines) > 1 else summary
-
-
-
# ---------------------------------------------------------------------------
# Phase-dispatch wrapper (public entry point)
# ---------------------------------------------------------------------------
+
def generate_rich_order_context(
game: Any,
power_name: str,
@@ -846,22 +846,14 @@ def generate_rich_order_context(
phase_type = game.current_short_phase[-1]
- if phase_type == "M": # Movement
- return _generate_rich_order_context_movement(
- game, power_name, possible_orders_for_power
- )
+ if phase_type == "M": # Movement
+ return _generate_rich_order_context_movement(game, power_name, possible_orders_for_power)
- if phase_type == "R": # Retreat
- return _generate_rich_order_context_retreat(
- game, power_name, possible_orders_for_power
- )
+ if phase_type == "R": # Retreat
+ return _generate_rich_order_context_retreat(game, power_name, possible_orders_for_power)
- if phase_type == "A": # Adjustment (build / disband)
- return _generate_rich_order_context_adjustment(
- game, power_name, possible_orders_for_power
- )
+ if phase_type == "A": # Adjustment (build / disband)
+ return _generate_rich_order_context_adjustment(game, power_name, possible_orders_for_power)
# Fallback – treat unknown formats as movement
- return _generate_rich_order_context_movement(
- game, power_name, possible_orders_for_power
- )
+ return _generate_rich_order_context_movement(game, power_name, possible_orders_for_power)
diff --git a/ai_diplomacy/prompt_constructor.py b/ai_diplomacy/prompt_constructor.py
index d77cffb..07d4f25 100644
--- a/ai_diplomacy/prompt_constructor.py
+++ b/ai_diplomacy/prompt_constructor.py
@@ -11,7 +11,6 @@ from .possible_order_context import (
generate_rich_order_context,
generate_rich_order_context_xml,
)
-import os
from .game_history import GameHistory # Assuming GameHistory is correctly importable
# placeholder for diplomacy.Game to avoid circular or direct dependency if not needed for typehinting only
@@ -230,4 +229,3 @@ def construct_order_generation_prompt(
logger.debug(f"Final order generation prompt preview for {power_name}: {final_prompt[:500]}...")
return final_prompt
-
diff --git a/ai_diplomacy/utils.py b/ai_diplomacy/utils.py
index b1808b9..e47643f 100644
--- a/ai_diplomacy/utils.py
+++ b/ai_diplomacy/utils.py
@@ -16,7 +16,7 @@ from ..config import config
if TYPE_CHECKING:
from .clients import BaseModelClient
# If DiplomacyAgent is used for type hinting for an 'agent' parameter:
- # from .agent import DiplomacyAgent
+ # from .agent import DiplomacyAgent
logger = logging.getLogger("utils")
logger.setLevel(logging.INFO)
@@ -32,12 +32,12 @@ def atomic_write_json(data: dict, filepath: str):
dir_name = os.path.dirname(filepath)
if dir_name:
os.makedirs(dir_name, exist_ok=True)
-
+
# Write to a temporary file in the same directory
temp_filepath = f"{filepath}.tmp.{os.getpid()}"
- with open(temp_filepath, 'w', encoding='utf-8') as f:
+ with open(temp_filepath, "w", encoding="utf-8") as f:
json.dump(data, f, indent=4)
-
+
# Atomically rename the temporary file to the final destination
os.rename(temp_filepath, filepath)
except Exception as e:
@@ -55,15 +55,15 @@ def assign_models_to_powers() -> Dict[str, str]:
Example usage: define which model each power uses.
Return a dict: { power_name: model_id, ... }
POWERS = ['AUSTRIA', 'ENGLAND', 'FRANCE', 'GERMANY', 'ITALY', 'RUSSIA', 'TURKEY']
- Models supported: o3-mini, o4-mini, o3, gpt-4o, gpt-4o-mini,
- claude-opus-4-20250514, claude-sonnet-4-20250514, claude-3-5-haiku-20241022, claude-3-5-sonnet-20241022, claude-3-7-sonnet-20250219
- gemini-2.0-flash, gemini-2.5-flash-preview-04-17, gemini-2.5-pro-preview-03-25,
+ Models supported: o3-mini, o4-mini, o3, gpt-4o, gpt-4o-mini,
+ claude-opus-4-20250514, claude-sonnet-4-20250514, claude-3-5-haiku-20241022, claude-3-5-sonnet-20241022, claude-3-7-sonnet-20250219
+ gemini-2.0-flash, gemini-2.5-flash-preview-04-17, gemini-2.5-pro-preview-03-25,
deepseek-chat, deepseek-reasoner
openrouter-meta-llama/llama-3.3-70b-instruct, openrouter-qwen/qwen3-235b-a22b, openrouter-microsoft/phi-4-reasoning-plus:free,
- openrouter-deepseek/deepseek-prover-v2:free, openrouter-meta-llama/llama-4-maverick:free, openrouter-nvidia/llama-3.3-nemotron-super-49b-v1:free,
+ openrouter-deepseek/deepseek-prover-v2:free, openrouter-meta-llama/llama-4-maverick:free, openrouter-nvidia/llama-3.3-nemotron-super-49b-v1:free,
openrouter-google/gemma-3-12b-it:free, openrouter-google/gemini-2.5-flash-preview-05-20
"""
-
+
# POWER MODELS
"""
return {
@@ -77,13 +77,13 @@ def assign_models_to_powers() -> Dict[str, str]:
}
"""
# TEST MODELS
-
+
return {
"AUSTRIA": "openrouter-mistralai/mistral-small-3.2-24b-instruct",
"ENGLAND": "openrouter-mistralai/mistral-small-3.2-24b-instruct",
"FRANCE": "openrouter-mistralai/mistral-small-3.2-24b-instruct",
"GERMANY": "openrouter-mistralai/mistral-small-3.2-24b-instruct",
- "ITALY": "openrouter-mistralai/mistral-small-3.2-24b-instruct",
+ "ITALY": "openrouter-mistralai/mistral-small-3.2-24b-instruct",
"RUSSIA": "openrouter-mistralai/mistral-small-3.2-24b-instruct",
"TURKEY": "openrouter-mistralai/mistral-small-3.2-24b-instruct",
}
@@ -92,30 +92,26 @@ def assign_models_to_powers() -> Dict[str, str]:
def get_special_models() -> Dict[str, str]:
"""
Define models for special purposes like phase summaries and formatting.
-
+
These can be overridden via environment variables:
- AI_DIPLOMACY_NARRATIVE_MODEL: Model for phase summaries (default: "o3")
- AI_DIPLOMACY_FORMATTER_MODEL: Model for JSON formatting (default: "google/gemini-2.5-flash-lite-preview-06-17")
-
+
Returns:
dict: {
"phase_summary": model for generating narrative phase summaries,
"formatter": model for formatting natural language to JSON
}
-
+
Examples:
# Use Claude for phase summaries
export AI_DIPLOMACY_NARRATIVE_MODEL="claude-3-5-sonnet-20241022"
-
+
# Use a different Gemini model for formatting
export AI_DIPLOMACY_FORMATTER_MODEL="gemini-2.0-flash"
"""
- return {
- "phase_summary": config.AI_DIPLOMACY_NARRATIVE_MODEL,
- "formatter": config.AI_DIPLOMACY_FORMATTER_MODEL
- }
-
-
+ return {"phase_summary": config.AI_DIPLOMACY_NARRATIVE_MODEL, "formatter": config.AI_DIPLOMACY_FORMATTER_MODEL}
+
def gather_possible_orders(game: Game, power_name: str) -> Dict[str, List[str]]:
"""
@@ -132,7 +128,7 @@ def gather_possible_orders(game: Game, power_name: str) -> Dict[str, List[str]]:
async def get_valid_orders(
game: Game,
- client, # BaseModelClient instance
+ client, # BaseModelClient instance
board_state,
power_name: str,
possible_orders: Dict[str, List[str]],
@@ -170,20 +166,24 @@ async def get_valid_orders(
)
invalid_info: list[str] = []
- valid: list[str] = []
+ valid: list[str] = []
invalid: list[str] = []
# ── 2. Type check ──────────────────────────────────────────
if not isinstance(raw_orders, list):
- logger.warning("[%s] Orders received from LLM are not a list: %s. Using fallback.",
- power_name, raw_orders)
+ logger.warning("[%s] Orders received from LLM are not a list: %s. Using fallback.", power_name, raw_orders)
model_error_stats[client.model_name]["order_decoding_errors"] += 1
return {"valid": client.fallback_orders(possible_orders), "invalid": []}
# ── 3. Round-trip validation with engine ───────────────────
CODE_TO_ENGINE = {
- "AUT": "AUSTRIA", "ENG": "ENGLAND", "FRA": "FRANCE",
- "GER": "GERMANY", "ITA": "ITALY", "RUS": "RUSSIA", "TUR": "TURKEY",
+ "AUT": "AUSTRIA",
+ "ENG": "ENGLAND",
+ "FRA": "FRANCE",
+ "GER": "GERMANY",
+ "ITA": "ITALY",
+ "RUS": "RUSSIA",
+ "TUR": "TURKEY",
}
engine_power = power_name if power_name in game.powers else CODE_TO_ENGINE[power_name]
@@ -202,16 +202,16 @@ async def get_valid_orders(
game.set_orders(engine_power, [upper])
normed = game.get_orders(engine_power)
- if normed: # accepted
+ if normed: # accepted
valid.append(normed[0])
- else: # rejected
+ else: # rejected
invalid.append(upper)
invalid_info.append(f"Order '{move}' is invalid for {power_name}")
game.clear_orders(engine_power) # clean slate for main engine flow
# ── 4. Legacy logging & stats updates ──────────────────────
- if invalid_info: # at least one bad move
+ if invalid_info: # at least one bad move
logger.debug("[%s] Invalid orders: %s", power_name, ", ".join(invalid_info))
model_error_stats[client.model_name]["order_decoding_errors"] += 1
logger.debug("[%s] Some orders invalid, using fallback.", power_name)
@@ -226,8 +226,6 @@ async def get_valid_orders(
return {"valid": valid, "invalid": invalid}
-
-
def normalize_and_compare_orders(
issued_orders: Dict[str, List[str]],
accepted_orders_dict: Dict[str, List[str]],
@@ -325,18 +323,18 @@ def load_prompt(filename: str, prompts_dir: Optional[str] = None) -> str:
3. Elif *prompts_dir* is provided → join prompts_dir + filename.
4. Otherwise → join the package’s default prompts dir.
"""
- if os.path.isabs(filename): # rule 1
+ if os.path.isabs(filename): # rule 1
prompt_path = filename
- elif os.path.dirname(filename): # rule 2 (has slash)
+ elif os.path.dirname(filename): # rule 2 (has slash)
# If it's a relative path with directory, join with prompts_dir if provided
if prompts_dir:
prompt_path = os.path.join(prompts_dir, filename)
else:
default_dir = os.path.join(os.path.dirname(__file__), "prompts")
prompt_path = os.path.join(default_dir, filename)
- elif prompts_dir: # rule 3
+ elif prompts_dir: # rule 3
prompt_path = os.path.join(prompts_dir, filename)
- else: # rule 4
+ else: # rule 4
default_dir = os.path.join(os.path.dirname(__file__), "prompts")
prompt_path = os.path.join(default_dir, filename)
@@ -348,16 +346,14 @@ def load_prompt(filename: str, prompts_dir: Optional[str] = None) -> str:
return ""
-
-
# == New LLM Response Logging Function ==
def log_llm_response(
log_file_path: str,
model_name: str,
- power_name: Optional[str], # Optional for non-power-specific calls like summary
+ power_name: Optional[str], # Optional for non-power-specific calls like summary
phase: str,
response_type: str,
- raw_input_prompt: str, # Added new parameter for the raw input
+ raw_input_prompt: str, # Added new parameter for the raw input
raw_response: str,
success: str, # Changed from bool to str
):
@@ -365,8 +361,8 @@ def log_llm_response(
try:
# Ensure the directory exists
log_dir = os.path.dirname(log_file_path)
- if log_dir: # Ensure log_dir is not empty (e.g., if path is just a filename)
- os.makedirs(log_dir, exist_ok=True)
+ if log_dir: # Ensure log_dir is not empty (e.g., if path is just a filename)
+ os.makedirs(log_dir, exist_ok=True)
# Check if file exists and has content to determine if we need headers
file_exists = os.path.isfile(log_file_path) and os.path.getsize(log_file_path) > 0
@@ -374,34 +370,38 @@ def log_llm_response(
with open(log_file_path, "a", newline="", encoding="utf-8") as csvfile:
# Added "raw_input" to fieldnames
fieldnames = ["model", "power", "phase", "response_type", "raw_input", "raw_response", "success"]
- writer = csv.DictWriter(csvfile, fieldnames=fieldnames,
- quoting=csv.QUOTE_ALL, # Quote all fields to handle commas and newlines
- escapechar='\\') # Use backslash for escaping
+ writer = csv.DictWriter(
+ csvfile,
+ fieldnames=fieldnames,
+ quoting=csv.QUOTE_ALL, # Quote all fields to handle commas and newlines
+ escapechar="\\",
+ ) # Use backslash for escaping
if not file_exists:
writer.writeheader() # Write header only if file is new
- writer.writerow({
- "model": model_name,
- "power": power_name if power_name else "game", # Use 'game' if no specific power
- "phase": phase,
- "response_type": response_type,
- "raw_input": raw_input_prompt, # Added raw_input to the row
- "raw_response": raw_response,
- "success": success,
- })
+ writer.writerow(
+ {
+ "model": model_name,
+ "power": power_name if power_name else "game", # Use 'game' if no specific power
+ "phase": phase,
+ "response_type": response_type,
+ "raw_input": raw_input_prompt, # Added raw_input to the row
+ "raw_response": raw_response,
+ "success": success,
+ }
+ )
except Exception as e:
logger.error(f"Failed to log LLM response to {log_file_path}: {e}", exc_info=True)
-
async def run_llm_and_log(
client: "BaseModelClient",
prompt: str,
- log_file_path: str, # Kept for context, but not used for logging here
+ log_file_path: str, # Kept for context, but not used for logging here
power_name: Optional[str], # Kept for context, but not used for logging here
- phase: str, # Kept for context, but not used for logging here
- response_type: str, # Kept for context, but not used for logging here
+ phase: str, # Kept for context, but not used for logging here
+ response_type: str, # Kept for context, but not used for logging here
temperature: float = 0.0,
*,
attempts: int = 5,
@@ -428,43 +428,34 @@ async def run_llm_and_log(
except Exception as e:
if attempt == attempts - 1:
logger.error(
- f"API Error during LLM call for {client.model_name}/{power_name}/{response_type} "
- f"in phase {phase}: {e}",
+ f"API Error during LLM call for {client.model_name}/{power_name}/{response_type} in phase {phase}: {e}",
exc_info=True,
)
# Back-off before the next attempt (unless this was the last)
- delay = backoff_base * (backoff_factor ** attempt) + random.uniform(0, jitter)
+ delay = backoff_base * (backoff_factor**attempt) + random.uniform(0, jitter)
await asyncio.sleep(delay)
return raw_response
-
-# This generates a few lines of random alphanum chars to inject into the
-# system prompt. This lets us use temp=0 while still getting variation
+# This generates a few lines of random alphanum chars to inject into the
+# system prompt. This lets us use temp=0 while still getting variation
# between trials.
-# Temp=0 is important for better performance on deciding moves, and to
+# Temp=0 is important for better performance on deciding moves, and to
# ensure valid json outputs.
def generate_random_seed(n_lines: int = 5, n_chars_per_line: int = 80):
- # Generate x lines of y random alphanumeric characters
- seed_lines = [
- ''.join(random.choices(string.ascii_letters + string.digits, k=n_chars_per_line))
- for _ in range(n_lines)
- ]
- random_seed_block = (
- "\n" +
- "\n".join(seed_lines) +
- "\n"
- )
- return random_seed_block
+ # Generate x lines of y random alphanumeric characters
+ seed_lines = ["".join(random.choices(string.ascii_letters + string.digits, k=n_chars_per_line)) for _ in range(n_lines)]
+ random_seed_block = "\n" + "\n".join(seed_lines) + "\n"
+ return random_seed_block
def get_prompt_path(prompt_name: str) -> str:
"""Get the appropriate prompt path based on USE_UNFORMATTED_PROMPTS setting.
-
+
Args:
prompt_name: Base name of the prompt file (e.g., "conversation_instructions.txt")
-
+
Returns:
str: Either "unformatted/{prompt_name}" or just "{prompt_name}"
"""
@@ -472,36 +463,37 @@ def get_prompt_path(prompt_name: str) -> str:
return f"unformatted/{prompt_name}"
else:
return prompt_name
-
-def normalize_recipient_name(recipient: str) -> str:
- """Normalize recipient names to handle LLM typos and abbreviations."""
- if not recipient:
- return recipient
-
- recipient = recipient.upper().strip()
-
- # Handle common LLM typos and abbreviations found in data
- name_mapping = {
- 'EGMANY': 'GERMANY',
- 'GERMAN': 'GERMANY',
- 'UK': 'ENGLAND',
- 'BRIT': 'ENGLAND',
- 'ENGLAND': 'ENGLAND', # Keep as-is
- 'FRANCE': 'FRANCE', # Keep as-is
- 'GERMANY': 'GERMANY', # Keep as-is
- 'ITALY': 'ITALY', # Keep as-is
- 'AUSTRIA': 'AUSTRIA', # Keep as-is
- 'RUSSIA': 'RUSSIA', # Keep as-is
- 'TURKEY': 'TURKEY', # Keep as-is
- 'Germany': 'GERMANY',
- 'England': 'ENGLAND',
- 'France': 'FRANCE',
- 'Italy': 'ITALY',
- 'Russia': 'RUSSIA',
- 'Austria': 'AUSTRIA',
- 'Turkey': 'TURKEY',
- }
-
- normalized = name_mapping.get(recipient, recipient)
- return normalized
+
+def normalize_recipient_name(recipient: str) -> str:
+ """Normalize recipient names to handle LLM typos and abbreviations."""
+ if not recipient:
+ return recipient
+
+ recipient = recipient.upper().strip()
+
+ # Handle common LLM typos and abbreviations found in data
+ name_mapping = {
+ "EGMANY": "GERMANY",
+ "GERMAN": "GERMANY",
+ "UK": "ENGLAND",
+ "BRIT": "ENGLAND",
+ "ENGLAND": "ENGLAND", # Keep as-is
+ "FRANCE": "FRANCE", # Keep as-is
+ "GERMANY": "GERMANY", # Keep as-is
+ "ITALY": "ITALY", # Keep as-is
+ "AUSTRIA": "AUSTRIA", # Keep as-is
+ "RUSSIA": "RUSSIA", # Keep as-is
+ "TURKEY": "TURKEY", # Keep as-is
+ "Germany": "GERMANY",
+ "England": "ENGLAND",
+ "France": "FRANCE",
+ "Italy": "ITALY",
+ "Russia": "RUSSIA",
+ "Austria": "AUSTRIA",
+ "Turkey": "TURKEY",
+ }
+
+ normalized = name_mapping.get(recipient, recipient)
+
+ return normalized