Adding PowerEnum as a seperate model

The PowerEnum correctly handles some misspellings. It can be easily
expanded to handle more within the _POWER_ALIASES dict.
This commit is contained in:
Tyler Marques 2025-07-03 12:06:47 -07:00
parent a241e34496
commit 540c2003e8
No known key found for this signature in database
GPG key ID: CB99EDCF41D3016F
8 changed files with 1236 additions and 1053 deletions

View file

@ -5,13 +5,16 @@ import json
import re import re
import json_repair import json_repair
import json5 # More forgiving JSON parser import json5 # More forgiving JSON parser
import ast # For literal_eval import ast
from ..config import config
# Assuming BaseModelClient is importable from clients.py in the same directory # Assuming BaseModelClient is importable from clients.py in the same directory
from .clients import BaseModelClient, load_model_client from .clients import BaseModelClient, load_model_client
# Import load_prompt and the new logging wrapper from utils # Import load_prompt and the new logging wrapper from utils
from .utils import load_prompt, run_llm_and_log, log_llm_response, get_prompt_path from .utils import load_prompt, run_llm_and_log, log_llm_response, get_prompt_path
from .prompt_constructor import build_context_prompt # Added import from .prompt_constructor import build_context_prompt # Added import
from .clients import GameHistory from .clients import GameHistory
from diplomacy import Game from diplomacy import Game
from .formatter import format_with_gemini_flash, FORMAT_ORDER_DIARY, FORMAT_NEGOTIATION_DIARY, FORMAT_STATE_UPDATE from .formatter import format_with_gemini_flash, FORMAT_ORDER_DIARY, FORMAT_NEGOTIATION_DIARY, FORMAT_STATE_UPDATE
@ -22,6 +25,7 @@ logger = logging.getLogger(__name__)
ALL_POWERS = frozenset({"AUSTRIA", "ENGLAND", "FRANCE", "GERMANY", "ITALY", "RUSSIA", "TURKEY"}) ALL_POWERS = frozenset({"AUSTRIA", "ENGLAND", "FRANCE", "GERMANY", "ITALY", "RUSSIA", "TURKEY"})
ALLOWED_RELATIONSHIPS = ["Enemy", "Unfriendly", "Neutral", "Friendly", "Ally"] ALLOWED_RELATIONSHIPS = ["Enemy", "Unfriendly", "Neutral", "Friendly", "Ally"]
# == New: Helper function to load prompt files reliably == # == New: Helper function to load prompt files reliably ==
def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> Optional[str]: def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> Optional[str]:
"""Loads a prompt template from the prompts directory.""" """Loads a prompt template from the prompts directory."""
@ -31,10 +35,10 @@ def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> Optio
else: else:
# Construct path relative to this file's location # Construct path relative to this file's location
current_dir = os.path.dirname(os.path.abspath(__file__)) current_dir = os.path.dirname(os.path.abspath(__file__))
default_prompts_dir = os.path.join(current_dir, 'prompts') default_prompts_dir = os.path.join(current_dir, "prompts")
filepath = os.path.join(default_prompts_dir, filename) filepath = os.path.join(default_prompts_dir, filename)
with open(filepath, 'r', encoding='utf-8') as f: with open(filepath, "r", encoding="utf-8") as f:
return f.read() return f.read()
except FileNotFoundError: except FileNotFoundError:
logger.error(f"Prompt file not found: {filepath}") logger.error(f"Prompt file not found: {filepath}")
@ -43,12 +47,14 @@ def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> Optio
logger.error(f"Error loading prompt file {filepath}: {e}") logger.error(f"Error loading prompt file {filepath}: {e}")
return None return None
class DiplomacyAgent: class DiplomacyAgent:
""" """
Represents a stateful AI agent playing as a specific power in Diplomacy. Represents a stateful AI agent playing as a specific power in Diplomacy.
It holds the agent's goals, relationships, and private journal, It holds the agent's goals, relationships, and private journal,
and uses a BaseModelClient instance to interact with the LLM. and uses a BaseModelClient instance to interact with the LLM.
""" """
def __init__( def __init__(
self, self,
power_name: str, power_name: str,
@ -91,29 +97,26 @@ class DiplomacyAgent:
# --- Load and set the appropriate system prompt --- # --- Load and set the appropriate system prompt ---
# Get the directory containing the current file (agent.py) # Get the directory containing the current file (agent.py)
current_dir = os.path.dirname(os.path.abspath(__file__)) current_dir = os.path.dirname(os.path.abspath(__file__))
default_prompts_path = os.path.join(current_dir, "prompts") default_prompts_path = os.path.join(current_dir, "prompts")
prompts_root = self.prompts_dir or default_prompts_path prompts_root = self.prompts_dir or default_prompts_path
power_prompt_name = f"{power_name.lower()}_system_prompt.txt" power_prompt_name = f"{power_name.lower()}_system_prompt.txt"
default_prompt_name = "system_prompt.txt" default_prompt_name = "system_prompt.txt"
power_prompt_path = os.path.join(prompts_root, power_prompt_name) power_prompt_path = os.path.join(prompts_root, power_prompt_name)
default_prompt_path = os.path.join(prompts_root, default_prompt_name) default_prompt_path = os.path.join(prompts_root, default_prompt_name)
system_prompt_content = load_prompt(power_prompt_path) system_prompt_content = load_prompt(power_prompt_path)
if not system_prompt_content: if not system_prompt_content:
logger.warning( logger.warning(f"Power-specific prompt not found at {power_prompt_path}. Falling back to default.")
f"Power-specific prompt not found at {power_prompt_path}. Falling back to default."
)
system_prompt_content = load_prompt(default_prompt_path) system_prompt_content = load_prompt(default_prompt_path)
if system_prompt_content: # Ensure we actually have content before setting
if system_prompt_content: # Ensure we actually have content before setting self.client.set_system_prompt(system_prompt_content)
self.client.set_system_prompt(system_prompt_content)
else: else:
logger.error(f"Could not load default system prompt either! Agent {power_name} may not function correctly.") logger.error(f"Could not load default system prompt either! Agent {power_name} may not function correctly.")
logger.info(f"Initialized DiplomacyAgent for {self.power_name} with goals: {self.goals}") logger.info(f"Initialized DiplomacyAgent for {self.power_name} with goals: {self.goals}")
self.add_journal_entry(f"Agent initialized. Initial Goals: {self.goals}") self.add_journal_entry(f"Agent initialized. Initial Goals: {self.goals}")
@ -131,11 +134,16 @@ class DiplomacyAgent:
text = re.sub(r'\n\s+"(\w+)"\s*:', r'"\1":', text) # Remove newlines before keys text = re.sub(r'\n\s+"(\w+)"\s*:', r'"\1":', text) # Remove newlines before keys
# Fix specific patterns that cause trouble # Fix specific patterns that cause trouble
problematic_patterns = [ problematic_patterns = [
'negotiation_summary', 'relationship_updates', 'updated_relationships', "negotiation_summary",
'order_summary', 'goals', 'relationships', 'intent' "relationship_updates",
"updated_relationships",
"order_summary",
"goals",
"relationships",
"intent",
] ]
for pattern in problematic_patterns: for pattern in problematic_patterns:
text = re.sub(fr'\n\s*"{pattern}"', f'"{pattern}"', text) text = re.sub(rf'\n\s*"{pattern}"', f'"{pattern}"', text)
# Try different patterns to extract JSON # Try different patterns to extract JSON
# Order matters - try most specific patterns first # Order matters - try most specific patterns first
@ -169,7 +177,9 @@ class DiplomacyAgent:
logger.debug(f"[{self.power_name}] Successfully parsed JSON object with pattern {pattern_idx}, match {match_idx}") logger.debug(f"[{self.power_name}] Successfully parsed JSON object with pattern {pattern_idx}, match {match_idx}")
return result return result
else: else:
logger.warning(f"[{self.power_name}] Parsed JSON with pattern {pattern_idx}, match {match_idx}, but got type {type(result)} instead of dict. Content: {str(result)[:200]}") logger.warning(
f"[{self.power_name}] Parsed JSON with pattern {pattern_idx}, match {match_idx}, but got type {type(result)} instead of dict. Content: {str(result)[:200]}"
)
except json.JSONDecodeError as e_initial: except json.JSONDecodeError as e_initial:
logger.debug(f"[{self.power_name}] Standard JSON parse failed: {e_initial}") logger.debug(f"[{self.power_name}] Standard JSON parse failed: {e_initial}")
@ -179,16 +189,20 @@ class DiplomacyAgent:
cleaned_match_candidate = json_text cleaned_match_candidate = json_text
# Pattern 1: Removes 'Sentence.' when followed by ',', '}', or ']' # Pattern 1: Removes 'Sentence.' when followed by ',', '}', or ']'
cleaned_match_candidate = re.sub(r'\s*([A-Z][\w\s,]*?\.(?:\s+[A-Z][\w\s,]*?\.)*)\s*(?=[,\}\]])', '', cleaned_match_candidate) cleaned_match_candidate = re.sub(
r"\s*([A-Z][\w\s,]*?\.(?:\s+[A-Z][\w\s,]*?\.)*)\s*(?=[,\}\]])", "", cleaned_match_candidate
)
# Pattern 2: Removes 'Sentence.' when it's at the very end, before the final '}' of the current scope # Pattern 2: Removes 'Sentence.' when it's at the very end, before the final '}' of the current scope
cleaned_match_candidate = re.sub(r'\s*([A-Z][\w\s,]*?\.(?:\s+[A-Z][\w\s,]*?\.)*)\s*(?=\s*\}\s*$)', '', cleaned_match_candidate) cleaned_match_candidate = re.sub(
r"\s*([A-Z][\w\s,]*?\.(?:\s+[A-Z][\w\s,]*?\.)*)\s*(?=\s*\}\s*$)", "", cleaned_match_candidate
)
# Pattern 3: Fix for newlines and spaces before JSON keys (common problem with LLMs) # Pattern 3: Fix for newlines and spaces before JSON keys (common problem with LLMs)
cleaned_match_candidate = re.sub(r'\n\s+"(\w+)"\s*:', r'"\1":', cleaned_match_candidate) cleaned_match_candidate = re.sub(r'\n\s+"(\w+)"\s*:', r'"\1":', cleaned_match_candidate)
# Pattern 4: Fix trailing commas in JSON objects # Pattern 4: Fix trailing commas in JSON objects
cleaned_match_candidate = re.sub(r',\s*}', '}', cleaned_match_candidate) cleaned_match_candidate = re.sub(r",\s*}", "}", cleaned_match_candidate)
# Pattern 5: Handle specific known problematic patterns # Pattern 5: Handle specific known problematic patterns
for pattern in problematic_patterns: for pattern in problematic_patterns:
@ -211,7 +225,9 @@ class DiplomacyAgent:
logger.debug(f"[{self.power_name}] Successfully parsed JSON object with json5") logger.debug(f"[{self.power_name}] Successfully parsed JSON object with json5")
return result return result
else: else:
logger.warning(f"[{self.power_name}] Parsed with json5, but got type {type(result)} instead of dict. Content: {str(result)[:200]}") logger.warning(
f"[{self.power_name}] Parsed with json5, but got type {type(result)} instead of dict. Content: {str(result)[:200]}"
)
except Exception as e: except Exception as e:
logger.debug(f"[{self.power_name}] json5 parse failed: {e}") logger.debug(f"[{self.power_name}] json5 parse failed: {e}")
@ -222,21 +238,23 @@ class DiplomacyAgent:
logger.debug(f"[{self.power_name}] Successfully parsed JSON object with json-repair") logger.debug(f"[{self.power_name}] Successfully parsed JSON object with json-repair")
return result return result
else: else:
logger.warning(f"[{self.power_name}] Parsed with json-repair, but got type {type(result)} instead of dict. Content: {str(result)[:200]}") logger.warning(
f"[{self.power_name}] Parsed with json-repair, but got type {type(result)} instead of dict. Content: {str(result)[:200]}"
)
except Exception as e: except Exception as e:
logger.debug(f"[{self.power_name}] json-repair failed: {e}") logger.debug(f"[{self.power_name}] json-repair failed: {e}")
# New Strategy: Parse markdown-like key-value pairs # New Strategy: Parse markdown-like key-value pairs
# Example: **key:** value # Example: **key:** value
# This comes after trying to find fenced JSON blocks but before broad fallbacks. # This comes after trying to find fenced JSON blocks but before broad fallbacks.
if not matches: # Only try if previous patterns didn't yield a dict from a match if not matches: # Only try if previous patterns didn't yield a dict from a match
try: try:
markdown_data = {} markdown_data = {}
# Regex to find **key:** value, where value can be multi-line until next **key:** or end of string # Regex to find **key:** value, where value can be multi-line until next **key:** or end of string
md_pattern = r"\*\*(?P<key>[^:]+):\*\*\s*(?P<value>[\s\S]*?)(?=(?:\n\s*\*\*|$))" md_pattern = r"\*\*(?P<key>[^:]+):\*\*\s*(?P<value>[\s\S]*?)(?=(?:\n\s*\*\*|$))"
for match in re.finditer(md_pattern, text, re.DOTALL): for match in re.finditer(md_pattern, text, re.DOTALL):
key_name = match.group('key').strip() key_name = match.group("key").strip()
value_str = match.group('value').strip() value_str = match.group("value").strip()
try: try:
# Attempt to evaluate the value string as a Python literal # Attempt to evaluate the value string as a Python literal
# This handles lists, strings, numbers, booleans, None # This handles lists, strings, numbers, booleans, None
@ -245,11 +263,13 @@ class DiplomacyAgent:
except (ValueError, SyntaxError) as e_ast: except (ValueError, SyntaxError) as e_ast:
# If ast.literal_eval fails, it might be a plain string that doesn't look like a literal # If ast.literal_eval fails, it might be a plain string that doesn't look like a literal
# Or it could be genuinely malformed. We'll take it as a string if it's not empty. # Or it could be genuinely malformed. We'll take it as a string if it's not empty.
if value_str: # Only add if it's a non-empty string if value_str: # Only add if it's a non-empty string
markdown_data[key_name] = value_str # Store as string markdown_data[key_name] = value_str # Store as string
logger.debug(f"[{self.power_name}] ast.literal_eval failed for key '{key_name}', value '{value_str[:50]}...': {e_ast}. Storing as string if non-empty.") logger.debug(
f"[{self.power_name}] ast.literal_eval failed for key '{key_name}', value '{value_str[:50]}...': {e_ast}. Storing as string if non-empty."
)
if markdown_data: # If we successfully extracted any key-value pairs this way if markdown_data: # If we successfully extracted any key-value pairs this way
# Check if essential keys are present, if needed, or just return if any data found # Check if essential keys are present, if needed, or just return if any data found
# For now, if markdown_data is populated, we assume it's the intended structure. # For now, if markdown_data is populated, we assume it's the intended structure.
logger.debug(f"[{self.power_name}] Successfully parsed markdown-like key-value format. Data: {str(markdown_data)[:200]}") logger.debug(f"[{self.power_name}] Successfully parsed markdown-like key-value format. Data: {str(markdown_data)[:200]}")
@ -262,17 +282,13 @@ class DiplomacyAgent:
# Fallback: Try to find ANY JSON-like structure # Fallback: Try to find ANY JSON-like structure
try: try:
# Find the first { and last } # Find the first { and last }
start = text.find('{') start = text.find("{")
end = text.rfind('}') + 1 # Include the closing brace end = text.rfind("}") + 1 # Include the closing brace
if start != -1 and end > start: if start != -1 and end > start:
potential_json = text[start:end] potential_json = text[start:end]
# Try all parsers on this extracted text # Try all parsers on this extracted text
for parser_name, parser_func in [ for parser_name, parser_func in [("json", json.loads), ("json5", json5.loads), ("json_repair", json_repair.loads)]:
("json", json.loads),
("json5", json5.loads),
("json_repair", json_repair.loads)
]:
try: try:
cleaned = self._clean_json_text(potential_json) if parser_name == "json" else potential_json cleaned = self._clean_json_text(potential_json) if parser_name == "json" else potential_json
result = parser_func(cleaned) result = parser_func(cleaned)
@ -280,24 +296,28 @@ class DiplomacyAgent:
logger.debug(f"[{self.power_name}] Fallback parse succeeded with {parser_name}, got dict.") logger.debug(f"[{self.power_name}] Fallback parse succeeded with {parser_name}, got dict.")
return result return result
else: else:
logger.warning(f"[{self.power_name}] Fallback parse with {parser_name} succeeded, but got type {type(result)} instead of dict. Content: {str(result)[:200]}") logger.warning(
f"[{self.power_name}] Fallback parse with {parser_name} succeeded, but got type {type(result)} instead of dict. Content: {str(result)[:200]}"
)
except Exception as e: except Exception as e:
logger.debug(f"[{self.power_name}] Fallback {parser_name} failed: {e}") logger.debug(f"[{self.power_name}] Fallback {parser_name} failed: {e}")
# If standard parsers failed, try aggressive cleaning # If standard parsers failed, try aggressive cleaning
try: try:
# Remove common non-JSON text that LLMs might add # Remove common non-JSON text that LLMs might add
cleaned_text = re.sub(r'[^{}[\]"\',:.\d\w\s_-]', '', potential_json) cleaned_text = re.sub(r'[^{}[\]"\',:.\d\w\s_-]', "", potential_json)
# Replace single quotes with double quotes (common LLM error) # Replace single quotes with double quotes (common LLM error)
text_fixed = re.sub(r"'([^']*)':", r'"\1":', cleaned_text) text_fixed = re.sub(r"'([^']*)':", r'"\1":', cleaned_text)
text_fixed = re.sub(r': *\'([^\']*)\'', r': "\1"', text_fixed) text_fixed = re.sub(r": *\'([^\']*)\'", r': "\1"', text_fixed)
result = json.loads(text_fixed) result = json.loads(text_fixed)
if isinstance(result, dict): if isinstance(result, dict):
logger.debug(f"[{self.power_name}] Aggressive cleaning worked, got dict.") logger.debug(f"[{self.power_name}] Aggressive cleaning worked, got dict.")
return result return result
else: else:
logger.warning(f"[{self.power_name}] Aggressive cleaning worked, but got type {type(result)} instead of dict. Content: {str(result)[:200]}") logger.warning(
f"[{self.power_name}] Aggressive cleaning worked, but got type {type(result)} instead of dict. Content: {str(result)[:200]}"
)
except json.JSONDecodeError: except json.JSONDecodeError:
pass pass
@ -311,7 +331,9 @@ class DiplomacyAgent:
logger.warning(f"[{self.power_name}] Last resort json-repair succeeded, got dict.") logger.warning(f"[{self.power_name}] Last resort json-repair succeeded, got dict.")
return result return result
else: else:
logger.warning(f"[{self.power_name}] Last resort json-repair succeeded, but got type {type(result)} instead of dict. Content: {str(result)[:200]}") logger.warning(
f"[{self.power_name}] Last resort json-repair succeeded, but got type {type(result)} instead of dict. Content: {str(result)[:200]}"
)
# If even the last resort doesn't give a dict, return empty dict # If even the last resort doesn't give a dict, return empty dict
return {} return {}
except Exception as e: except Exception as e:
@ -324,8 +346,8 @@ class DiplomacyAgent:
return text return text
# Remove trailing commas # Remove trailing commas
text = re.sub(r',\s*}', '}', text) text = re.sub(r",\s*}", "}", text)
text = re.sub(r',\s*]', ']', text) text = re.sub(r",\s*]", "]", text)
# Fix newlines before JSON keys # Fix newlines before JSON keys
text = re.sub(r'\n\s+"(\w+)"\s*:', r'"\1":', text) text = re.sub(r'\n\s+"(\w+)"\s*:', r'"\1":', text)
@ -334,19 +356,18 @@ class DiplomacyAgent:
text = re.sub(r"'(\w+)'\s*:", r'"\1":', text) text = re.sub(r"'(\w+)'\s*:", r'"\1":', text)
# Remove comments (if any) # Remove comments (if any)
text = re.sub(r'//.*$', '', text, flags=re.MULTILINE) text = re.sub(r"//.*$", "", text, flags=re.MULTILINE)
text = re.sub(r'/\*.*?\*/', '', text, flags=re.DOTALL) text = re.sub(r"/\*.*?\*/", "", text, flags=re.DOTALL)
# Fix unescaped quotes in values (basic attempt) # Fix unescaped quotes in values (basic attempt)
# This is risky but sometimes helps with simple cases # This is risky but sometimes helps with simple cases
text = re.sub(r':\s*"([^"]*)"([^",}\]]+)"', r': "\1\2"', text) text = re.sub(r':\s*"([^"]*)"([^",}\]]+)"', r': "\1\2"', text)
# Remove any BOM or zero-width spaces # Remove any BOM or zero-width spaces
text = text.replace('\ufeff', '').replace('\u200b', '') text = text.replace("\ufeff", "").replace("\u200b", "")
return text.strip() return text.strip()
def add_journal_entry(self, entry: str): def add_journal_entry(self, entry: str):
"""Adds a formatted entry string to the agent's private journal.""" """Adds a formatted entry string to the agent's private journal."""
# Ensure entry is a string # Ensure entry is a string
@ -358,7 +379,7 @@ class DiplomacyAgent:
def add_diary_entry(self, entry: str, phase: str): def add_diary_entry(self, entry: str, phase: str):
"""Adds a formatted entry to both the permanent and context diaries.""" """Adds a formatted entry to both the permanent and context diaries."""
if not isinstance(entry, str): if not isinstance(entry, str):
entry = str(entry) # Ensure it's a string entry = str(entry) # Ensure it's a string
formatted_entry = f"[{phase}] {entry}" formatted_entry = f"[{phase}] {entry}"
# Add to the permanent, unabridged record # Add to the permanent, unabridged record
@ -366,7 +387,9 @@ class DiplomacyAgent:
# Also add to the context diary, which will be periodically rebuilt # Also add to the context diary, which will be periodically rebuilt
self.private_diary.append(formatted_entry) self.private_diary.append(formatted_entry)
logger.info(f"[{self.power_name}] DIARY ENTRY ADDED for {phase}. Total full entries: {len(self.full_private_diary)}. New entry: {entry[:100]}...") logger.info(
f"[{self.power_name}] DIARY ENTRY ADDED for {phase}. Total full entries: {len(self.full_private_diary)}. New entry: {entry[:100]}..."
)
def format_private_diary_for_prompt(self) -> str: def format_private_diary_for_prompt(self) -> str:
""" """
@ -407,42 +430,43 @@ class DiplomacyAgent:
if not formatted_diary: if not formatted_diary:
return "(No diary entries to show)" return "(No diary entries to show)"
logger.info(f"[{self.power_name}] Formatted diary with {1 if consolidated_entry else 0} consolidated and {len(recent_entries)} recent entries. Preview: {formatted_diary[:250]}...") logger.info(
f"[{self.power_name}] Formatted diary with {1 if consolidated_entry else 0} consolidated and {len(recent_entries)} recent entries. Preview: {formatted_diary[:250]}..."
)
return formatted_diary return formatted_diary
# The consolidate_entire_diary method has been moved to ai_diplomacy/diary_logic.py # The consolidate_entire_diary method has been moved to ai_diplomacy/diary_logic.py
# to improve modularity and avoid circular dependencies. # to improve modularity and avoid circular dependencies.
# It is now called as `run_diary_consolidation(agent, game, ...)` from the main game loop. # It is now called as `run_diary_consolidation(agent, game, ...)` from the main game loop.
async def generate_negotiation_diary_entry(self, game: 'Game', game_history: GameHistory, log_file_path: str): async def generate_negotiation_diary_entry(self, game: "Game", game_history: GameHistory, log_file_path: str):
""" """
Generates a diary entry summarizing negotiations and updates relationships. Generates a diary entry summarizing negotiations and updates relationships.
This method now includes comprehensive LLM interaction logging. This method now includes comprehensive LLM interaction logging.
""" """
logger.info(f"[{self.power_name}] Generating negotiation diary entry for {game.current_short_phase}..." ) logger.info(f"[{self.power_name}] Generating negotiation diary entry for {game.current_short_phase}...")
full_prompt = "" # For logging in finally block full_prompt = "" # For logging in finally block
raw_response = "" # For logging in finally block raw_response = "" # For logging in finally block
success_status = "Failure: Initialized" # Default success_status = "Failure: Initialized" # Default
try: try:
# Load the prompt template file # Load the prompt template file
prompt_template_content = _load_prompt_file(get_prompt_path('negotiation_diary_prompt.txt'), prompts_dir=self.prompts_dir) prompt_template_content = _load_prompt_file(get_prompt_path("negotiation_diary_prompt.txt"), prompts_dir=self.prompts_dir)
if not prompt_template_content: if not prompt_template_content:
logger.error(f"[{self.power_name}] Could not load {get_prompt_path('negotiation_diary_prompt.txt')}. Skipping diary entry.") logger.error(f"[{self.power_name}] Could not load {get_prompt_path('negotiation_diary_prompt.txt')}. Skipping diary entry.")
success_status = "Failure: Prompt file not loaded" success_status = "Failure: Prompt file not loaded"
return # Exit early if prompt can't be loaded return # Exit early if prompt can't be loaded
# Prepare context for the prompt # Prepare context for the prompt
board_state_dict = game.get_state() board_state_dict = game.get_state()
board_state_str = f"Units: {board_state_dict.get('units', {})}, Centers: {board_state_dict.get('centers', {})}" board_state_str = f"Units: {board_state_dict.get('units', {})}, Centers: {board_state_dict.get('centers', {})}"
messages_this_round = game_history.get_messages_this_round( messages_this_round = game_history.get_messages_this_round(power_name=self.power_name, current_phase_name=game.current_short_phase)
power_name=self.power_name,
current_phase_name=game.current_short_phase
)
if not messages_this_round.strip() or messages_this_round.startswith("\n(No messages"): if not messages_this_round.strip() or messages_this_round.startswith("\n(No messages"):
messages_this_round = "(No messages involving your power this round that require deep reflection for diary. Focus on overall situation.)" messages_this_round = (
"(No messages involving your power this round that require deep reflection for diary. Focus on overall situation.)"
)
current_relationships_str = json.dumps(self.relationships) current_relationships_str = json.dumps(self.relationships)
current_goals_str = json.dumps(self.goals) current_goals_str = json.dumps(self.goals)
@ -462,28 +486,31 @@ class DiplomacyAgent:
# Do aggressive preprocessing of the template to fix the problematic patterns # Do aggressive preprocessing of the template to fix the problematic patterns
# This includes removing any newlines or whitespace before JSON keys that cause issues # This includes removing any newlines or whitespace before JSON keys that cause issues
for pattern in ['negotiation_summary', 'updated_relationships', 'relationship_updates', 'intent']: for pattern in ["negotiation_summary", "updated_relationships", "relationship_updates", "intent"]:
# Fix the "\n "key"" pattern that breaks .format() # Fix the "\n "key"" pattern that breaks .format()
prompt_template_content = re.sub( prompt_template_content = re.sub(rf'\n\s*"{pattern}"', f'"{pattern}"', prompt_template_content)
fr'\n\s*"{pattern}"',
f'"{pattern}"',
prompt_template_content
)
# Escape all curly braces in JSON examples to prevent format() from interpreting them # Escape all curly braces in JSON examples to prevent format() from interpreting them
# First, temporarily replace the actual template variables # First, temporarily replace the actual template variables
temp_vars = ['power_name', 'current_phase', 'messages_this_round', 'agent_goals', temp_vars = [
'agent_relationships', 'board_state_str', 'ignored_messages_context'] "power_name",
"current_phase",
"messages_this_round",
"agent_goals",
"agent_relationships",
"board_state_str",
"ignored_messages_context",
]
for var in temp_vars: for var in temp_vars:
prompt_template_content = prompt_template_content.replace(f'{{{var}}}', f'<<{var}>>') prompt_template_content = prompt_template_content.replace(f"{{{var}}}", f"<<{var}>>")
# Now escape all remaining braces (which should be JSON) # Now escape all remaining braces (which should be JSON)
prompt_template_content = prompt_template_content.replace('{', '{{') prompt_template_content = prompt_template_content.replace("{", "{{")
prompt_template_content = prompt_template_content.replace('}', '}}') prompt_template_content = prompt_template_content.replace("}", "}}")
# Restore the template variables # Restore the template variables
for var in temp_vars: for var in temp_vars:
prompt_template_content = prompt_template_content.replace(f'<<{var}>>', f'{{{var}}}') prompt_template_content = prompt_template_content.replace(f"<<{var}>>", f"{{{var}}}")
# Create a dictionary with safe values for formatting # Create a dictionary with safe values for formatting
format_vars = { format_vars = {
@ -495,7 +522,7 @@ class DiplomacyAgent:
"agent_goals": current_goals_str, "agent_goals": current_goals_str,
"allowed_relationships_str": ", ".join(ALLOWED_RELATIONSHIPS), "allowed_relationships_str": ", ".join(ALLOWED_RELATIONSHIPS),
"private_diary_summary": formatted_diary, "private_diary_summary": formatted_diary,
"ignored_messages_context": ignored_context "ignored_messages_context": ignored_context,
} }
# Now try to use the template after preprocessing # Now try to use the template after preprocessing
@ -511,16 +538,15 @@ class DiplomacyAgent:
logger.debug(f"[{self.power_name}] Negotiation diary prompt:\n{full_prompt[:500]}...") logger.debug(f"[{self.power_name}] Negotiation diary prompt:\n{full_prompt[:500]}...")
logger.debug(f"[{self.power_name}] Negotiation diary prompt:\n{full_prompt[:500]}...") logger.debug(f"[{self.power_name}] Negotiation diary prompt:\n{full_prompt[:500]}...")
raw_response = await run_llm_and_log( raw_response = await run_llm_and_log(
client=self.client, client=self.client,
prompt=full_prompt, prompt=full_prompt,
log_file_path=log_file_path, # Pass the main log file path log_file_path=log_file_path, # Pass the main log file path
power_name=self.power_name, power_name=self.power_name,
phase=game.current_short_phase, phase=game.current_short_phase,
response_type='negotiation_diary_raw', # For run_llm_and_log context response_type="negotiation_diary_raw", # For run_llm_and_log context
) )
logger.debug(f"[{self.power_name}] Raw negotiation diary response: {raw_response[:300]}...") logger.debug(f"[{self.power_name}] Raw negotiation diary response: {raw_response[:300]}...")
@ -528,14 +554,14 @@ class DiplomacyAgent:
parsed_data = None parsed_data = None
try: try:
# Conditionally format the response based on USE_UNFORMATTED_PROMPTS # Conditionally format the response based on USE_UNFORMATTED_PROMPTS
if os.getenv("USE_UNFORMATTED_PROMPTS") == "1": if config.USE_UNFORMATTED_PROMPTS:
# Format the natural language response into JSON # Format the natural language response into JSON
formatted_response = await format_with_gemini_flash( formatted_response = await format_with_gemini_flash(
raw_response, raw_response,
FORMAT_NEGOTIATION_DIARY, FORMAT_NEGOTIATION_DIARY,
power_name=self.power_name, power_name=self.power_name,
phase=game.current_short_phase, phase=game.current_short_phase,
log_file_path=log_file_path log_file_path=log_file_path,
) )
else: else:
# Use the raw response directly (already formatted) # Use the raw response directly (already formatted)
@ -548,23 +574,23 @@ class DiplomacyAgent:
success_status = "Failure: JSONDecodeError" success_status = "Failure: JSONDecodeError"
# Continue without parsed_data, rely on diary_entry_text if available or just log failure # Continue without parsed_data, rely on diary_entry_text if available or just log failure
diary_entry_text = "(LLM diary entry generation or parsing failed.)" # Fallback diary_entry_text = "(LLM diary entry generation or parsing failed.)" # Fallback
relationships_updated = False relationships_updated = False
if parsed_data: if parsed_data:
# Fix 1: Be more robust about extracting the negotiation_summary field # Fix 1: Be more robust about extracting the negotiation_summary field
diary_text_candidate = None diary_text_candidate = None
for key in ['negotiation_summary', 'summary', 'diary_entry']: for key in ["negotiation_summary", "summary", "diary_entry"]:
if key in parsed_data and isinstance(parsed_data[key], str) and parsed_data[key].strip(): if key in parsed_data and isinstance(parsed_data[key], str) and parsed_data[key].strip():
diary_text_candidate = parsed_data[key].strip() diary_text_candidate = parsed_data[key].strip()
logger.info(f"[{self.power_name}] Successfully extracted '{key}' for diary.") logger.info(f"[{self.power_name}] Successfully extracted '{key}' for diary.")
break break
if 'intent' in parsed_data: if "intent" in parsed_data:
if diary_text_candidate == None: if diary_text_candidate == None:
diary_text_candidate = parsed_data['intent'] diary_text_candidate = parsed_data["intent"]
else: else:
diary_text_candidate += '\nIntent: ' + parsed_data['intent'] diary_text_candidate += "\nIntent: " + parsed_data["intent"]
if diary_text_candidate: if diary_text_candidate:
diary_entry_text = diary_text_candidate diary_entry_text = diary_text_candidate
@ -574,7 +600,7 @@ class DiplomacyAgent:
# Fix 2: Be more robust about extracting relationship updates # Fix 2: Be more robust about extracting relationship updates
new_relationships = None new_relationships = None
for key in ['relationship_updates', 'updated_relationships', 'relationships']: for key in ["relationship_updates", "updated_relationships", "relationships"]:
if key in parsed_data and isinstance(parsed_data[key], dict): if key in parsed_data and isinstance(parsed_data[key], dict):
new_relationships = parsed_data[key] new_relationships = parsed_data[key]
logger.info(f"[{self.power_name}] Successfully extracted '{key}' for relationship updates.") logger.info(f"[{self.power_name}] Successfully extracted '{key}' for relationship updates.")
@ -587,7 +613,7 @@ class DiplomacyAgent:
r_title = str(r).title() r_title = str(r).title()
if p_upper in ALL_POWERS and p_upper != self.power_name and r_title in ALLOWED_RELATIONSHIPS: if p_upper in ALL_POWERS and p_upper != self.power_name and r_title in ALLOWED_RELATIONSHIPS:
valid_new_rels[p_upper] = r_title valid_new_rels[p_upper] = r_title
elif p_upper != self.power_name: # Log invalid relationship for a valid power elif p_upper != self.power_name: # Log invalid relationship for a valid power
logger.warning(f"[{self.power_name}] Invalid relationship '{r}' for power '{p}' in diary update. Keeping old.") logger.warning(f"[{self.power_name}] Invalid relationship '{r}' for power '{p}' in diary update. Keeping old.")
if valid_new_rels: if valid_new_rels:
@ -595,15 +621,17 @@ class DiplomacyAgent:
for p_changed, new_r_val in valid_new_rels.items(): for p_changed, new_r_val in valid_new_rels.items():
old_r_val = self.relationships.get(p_changed, "Unknown") old_r_val = self.relationships.get(p_changed, "Unknown")
if old_r_val != new_r_val: if old_r_val != new_r_val:
logger.info(f"[{self.power_name}] Relationship with {p_changed} changing from {old_r_val} to {new_r_val} based on diary.") logger.info(
f"[{self.power_name}] Relationship with {p_changed} changing from {old_r_val} to {new_r_val} based on diary."
)
self.relationships.update(valid_new_rels) self.relationships.update(valid_new_rels)
relationships_updated = True relationships_updated = True
success_status = "Success: Applied diary data (relationships updated)" success_status = "Success: Applied diary data (relationships updated)"
else: else:
logger.info(f"[{self.power_name}] No valid relationship updates found in diary response.") logger.info(f"[{self.power_name}] No valid relationship updates found in diary response.")
if success_status == "Success: Parsed diary data": # If only parsing was successful before if success_status == "Success: Parsed diary data": # If only parsing was successful before
success_status = "Success: Parsed, no valid relationship updates" success_status = "Success: Parsed, no valid relationship updates"
elif new_relationships is not None: # It was provided but not a dict elif new_relationships is not None: # It was provided but not a dict
logger.warning(f"[{self.power_name}] 'updated_relationships' from diary LLM was not a dictionary: {type(new_relationships)}") logger.warning(f"[{self.power_name}] 'updated_relationships' from diary LLM was not a dictionary: {type(new_relationships)}")
# Add the generated (or fallback) diary entry # Add the generated (or fallback) diary entry
@ -622,26 +650,26 @@ class DiplomacyAgent:
# Add a fallback diary entry in case of general error # Add a fallback diary entry in case of general error
self.add_diary_entry(f"(Error generating diary entry: {type(e).__name__})", game.current_short_phase) self.add_diary_entry(f"(Error generating diary entry: {type(e).__name__})", game.current_short_phase)
finally: finally:
if log_file_path: # Ensure log_file_path is provided if log_file_path: # Ensure log_file_path is provided
log_llm_response( log_llm_response(
log_file_path=log_file_path, log_file_path=log_file_path,
model_name=self.client.model_name if self.client else "UnknownModel", model_name=self.client.model_name if self.client else "UnknownModel",
power_name=self.power_name, power_name=self.power_name,
phase=game.current_short_phase if game else "UnknownPhase", phase=game.current_short_phase if game else "UnknownPhase",
response_type="negotiation_diary", # Specific type for CSV logging response_type="negotiation_diary", # Specific type for CSV logging
raw_input_prompt=full_prompt, raw_input_prompt=full_prompt,
raw_response=raw_response, raw_response=raw_response,
success=success_status success=success_status,
) )
async def generate_order_diary_entry(self, game: 'Game', orders: List[str], log_file_path: str): async def generate_order_diary_entry(self, game: "Game", orders: List[str], log_file_path: str):
""" """
Generates a diary entry reflecting on the decided orders. Generates a diary entry reflecting on the decided orders.
""" """
logger.info(f"[{self.power_name}] Generating order diary entry for {game.current_short_phase}...") logger.info(f"[{self.power_name}] Generating order diary entry for {game.current_short_phase}...")
# Load the prompt template # Load the prompt template
prompt_template = _load_prompt_file(get_prompt_path('order_diary_prompt.txt'), prompts_dir=self.prompts_dir) prompt_template = _load_prompt_file(get_prompt_path("order_diary_prompt.txt"), prompts_dir=self.prompts_dir)
if not prompt_template: if not prompt_template:
logger.error(f"[{self.power_name}] Could not load {get_prompt_path('order_diary_prompt.txt')}. Skipping diary entry.") logger.error(f"[{self.power_name}] Could not load {get_prompt_path('order_diary_prompt.txt')}. Skipping diary entry.")
return return
@ -656,23 +684,22 @@ class DiplomacyAgent:
# Do aggressive preprocessing on the template file # Do aggressive preprocessing on the template file
# Fix any whitespace or formatting issues that could break .format() # Fix any whitespace or formatting issues that could break .format()
for pattern in ['order_summary']: for pattern in ["order_summary"]:
prompt_template = re.sub(fr'\n\s*"{pattern}"', f'"{pattern}"', prompt_template) prompt_template = re.sub(rf'\n\s*"{pattern}"', f'"{pattern}"', prompt_template)
# Escape all curly braces in JSON examples to prevent format() from interpreting them # Escape all curly braces in JSON examples to prevent format() from interpreting them
# First, temporarily replace the actual template variables # First, temporarily replace the actual template variables
temp_vars = ['power_name', 'current_phase', 'orders_list_str', 'board_state_str', temp_vars = ["power_name", "current_phase", "orders_list_str", "board_state_str", "agent_goals", "agent_relationships"]
'agent_goals', 'agent_relationships']
for var in temp_vars: for var in temp_vars:
prompt_template = prompt_template.replace(f'{{{var}}}', f'<<{var}>>') prompt_template = prompt_template.replace(f"{{{var}}}", f"<<{var}>>")
# Now escape all remaining braces (which should be JSON) # Now escape all remaining braces (which should be JSON)
prompt_template = prompt_template.replace('{', '{{') prompt_template = prompt_template.replace("{", "{{")
prompt_template = prompt_template.replace('}', '}}') prompt_template = prompt_template.replace("}", "}}")
# Restore the template variables # Restore the template variables
for var in temp_vars: for var in temp_vars:
prompt_template = prompt_template.replace(f'<<{var}>>', f'{{{var}}}') prompt_template = prompt_template.replace(f"<<{var}>>", f"{{{var}}}")
# Create a dictionary of variables for template formatting # Create a dictionary of variables for template formatting
format_vars = { format_vars = {
@ -681,7 +708,7 @@ class DiplomacyAgent:
"orders_list_str": orders_list_str, "orders_list_str": orders_list_str,
"board_state_str": board_state_str, "board_state_str": board_state_str,
"agent_goals": goals_str, "agent_goals": goals_str,
"agent_relationships": relationships_str "agent_relationships": relationships_str,
} }
# Try to use the template with proper formatting # Try to use the template with proper formatting
@ -694,10 +721,8 @@ class DiplomacyAgent:
logger.debug(f"[{self.power_name}] Order diary prompt:\n{prompt[:300]}...") logger.debug(f"[{self.power_name}] Order diary prompt:\n{prompt[:300]}...")
response_data = None response_data = None
raw_response = None # Initialize raw_response raw_response = None # Initialize raw_response
try: try:
raw_response = await run_llm_and_log( raw_response = await run_llm_and_log(
client=self.client, client=self.client,
@ -705,12 +730,12 @@ class DiplomacyAgent:
log_file_path=log_file_path, log_file_path=log_file_path,
power_name=self.power_name, power_name=self.power_name,
phase=game.current_short_phase, phase=game.current_short_phase,
response_type='order_diary', response_type="order_diary",
) )
success_status = "FALSE" success_status = "FALSE"
response_data = None response_data = None
actual_diary_text = None # Variable to hold the final diary text actual_diary_text = None # Variable to hold the final diary text
if raw_response: if raw_response:
try: try:
@ -718,11 +743,7 @@ class DiplomacyAgent:
if os.getenv("USE_UNFORMATTED_PROMPTS") == "1": if os.getenv("USE_UNFORMATTED_PROMPTS") == "1":
# Format the natural language response into JSON # Format the natural language response into JSON
formatted_response = await format_with_gemini_flash( formatted_response = await format_with_gemini_flash(
raw_response, raw_response, FORMAT_ORDER_DIARY, power_name=self.power_name, phase=game.current_short_phase, log_file_path=log_file_path
FORMAT_ORDER_DIARY,
power_name=self.power_name,
phase=game.current_short_phase,
log_file_path=log_file_path
) )
else: else:
# Use the raw response directly (already formatted) # Use the raw response directly (already formatted)
@ -737,7 +758,7 @@ class DiplomacyAgent:
logger.info(f"[{self.power_name}] Successfully extracted 'order_summary' for order diary entry.") logger.info(f"[{self.power_name}] Successfully extracted 'order_summary' for order diary entry.")
else: else:
logger.warning(f"[{self.power_name}] 'order_summary' missing, invalid, or empty. Value was: {diary_text_candidate}") logger.warning(f"[{self.power_name}] 'order_summary' missing, invalid, or empty. Value was: {diary_text_candidate}")
success_status = "FALSE" # Explicitly set false if not found or invalid success_status = "FALSE" # Explicitly set false if not found or invalid
else: else:
# response_data is None (JSON parsing failed) # response_data is None (JSON parsing failed)
logger.warning(f"[{self.power_name}] Failed to parse JSON from order diary LLM response.") logger.warning(f"[{self.power_name}] Failed to parse JSON from order diary LLM response.")
@ -751,33 +772,35 @@ class DiplomacyAgent:
model_name=self.client.model_name, model_name=self.client.model_name,
power_name=self.power_name, power_name=self.power_name,
phase=game.current_short_phase, phase=game.current_short_phase,
response_type='order_diary', response_type="order_diary",
raw_input_prompt=prompt, # ENSURED raw_input_prompt=prompt, # ENSURED
raw_response=raw_response if raw_response else "", raw_response=raw_response if raw_response else "",
success=success_status success=success_status,
) )
if success_status == "TRUE" and actual_diary_text: if success_status == "TRUE" and actual_diary_text:
self.add_diary_entry(actual_diary_text, game.current_short_phase) self.add_diary_entry(actual_diary_text, game.current_short_phase)
logger.info(f"[{self.power_name}] Order diary entry generated and added.") logger.info(f"[{self.power_name}] Order diary entry generated and added.")
else: else:
fallback_diary = f"Submitted orders for {game.current_short_phase}: {', '.join(orders)}. (LLM failed to generate a specific diary entry)" fallback_diary = (
f"Submitted orders for {game.current_short_phase}: {', '.join(orders)}. (LLM failed to generate a specific diary entry)"
)
self.add_diary_entry(fallback_diary, game.current_short_phase) self.add_diary_entry(fallback_diary, game.current_short_phase)
logger.warning(f"[{self.power_name}] Failed to generate specific order diary entry. Added fallback.") logger.warning(f"[{self.power_name}] Failed to generate specific order diary entry. Added fallback.")
except Exception as e: except Exception as e:
# Ensure prompt is defined or handled if it might not be (it should be in this flow) # Ensure prompt is defined or handled if it might not be (it should be in this flow)
current_prompt = prompt if 'prompt' in locals() else "[prompt_unavailable_in_exception]" current_prompt = prompt if "prompt" in locals() else "[prompt_unavailable_in_exception]"
current_raw_response = raw_response if 'raw_response' in locals() and raw_response is not None else f"Error: {e}" current_raw_response = raw_response if "raw_response" in locals() and raw_response is not None else f"Error: {e}"
log_llm_response( log_llm_response(
log_file_path=log_file_path, log_file_path=log_file_path,
model_name=self.client.model_name if hasattr(self, 'client') else "UnknownModel", model_name=self.client.model_name if hasattr(self, "client") else "UnknownModel",
power_name=self.power_name, power_name=self.power_name,
phase=game.current_short_phase if 'game' in locals() and hasattr(game, 'current_short_phase') else "order_phase", phase=game.current_short_phase if "game" in locals() and hasattr(game, "current_short_phase") else "order_phase",
response_type='order_diary_exception', response_type="order_diary_exception",
raw_input_prompt=current_prompt, # ENSURED (using current_prompt for safety) raw_input_prompt=current_prompt, # ENSURED (using current_prompt for safety)
raw_response=current_raw_response, raw_response=current_raw_response,
success="FALSE" success="FALSE",
) )
fallback_diary = f"Submitted orders for {game.current_short_phase}: {', '.join(orders)}. (Critical error in diary generation process)" fallback_diary = f"Submitted orders for {game.current_short_phase}: {', '.join(orders)}. (Critical error in diary generation process)"
self.add_diary_entry(fallback_diary, game.current_short_phase) self.add_diary_entry(fallback_diary, game.current_short_phase)
@ -785,12 +808,7 @@ class DiplomacyAgent:
# Rest of the code remains the same # Rest of the code remains the same
async def generate_phase_result_diary_entry( async def generate_phase_result_diary_entry(
self, self, game: "Game", game_history: "GameHistory", phase_summary: str, all_orders: Dict[str, List[str]], log_file_path: str
game: 'Game',
game_history: 'GameHistory',
phase_summary: str,
all_orders: Dict[str, List[str]],
log_file_path: str
): ):
""" """
Generates a diary entry analyzing the actual phase results, Generates a diary entry analyzing the actual phase results,
@ -799,7 +817,7 @@ class DiplomacyAgent:
logger.info(f"[{self.power_name}] Generating phase result diary entry for {game.current_short_phase}...") logger.info(f"[{self.power_name}] Generating phase result diary entry for {game.current_short_phase}...")
# Load the template # Load the template
prompt_template = _load_prompt_file('phase_result_diary_prompt.txt', prompts_dir=self.prompts_dir) prompt_template = _load_prompt_file("phase_result_diary_prompt.txt", prompts_dir=self.prompts_dir)
if not prompt_template: if not prompt_template:
logger.error(f"[{self.power_name}] Could not load phase_result_diary_prompt.txt. Skipping diary entry.") logger.error(f"[{self.power_name}] Could not load phase_result_diary_prompt.txt. Skipping diary entry.")
return return
@ -841,7 +859,7 @@ class DiplomacyAgent:
your_negotiations=your_negotiations, your_negotiations=your_negotiations,
pre_phase_relationships=relationships_str, pre_phase_relationships=relationships_str,
agent_goals=goals_str, agent_goals=goals_str,
your_actual_orders=your_orders_str your_actual_orders=your_orders_str,
) )
logger.debug(f"[{self.power_name}] Phase result diary prompt:\n{prompt[:500]}...") logger.debug(f"[{self.power_name}] Phase result diary prompt:\n{prompt[:500]}...")
@ -856,7 +874,7 @@ class DiplomacyAgent:
log_file_path=log_file_path, log_file_path=log_file_path,
power_name=self.power_name, power_name=self.power_name,
phase=game.current_short_phase, phase=game.current_short_phase,
response_type='phase_result_diary', response_type="phase_result_diary",
) )
if raw_response and raw_response.strip(): if raw_response and raw_response.strip():
@ -866,7 +884,9 @@ class DiplomacyAgent:
success_status = "TRUE" success_status = "TRUE"
logger.info(f"[{self.power_name}] Phase result diary entry generated and added.") logger.info(f"[{self.power_name}] Phase result diary entry generated and added.")
else: else:
fallback_diary = f"Phase {game.current_short_phase} completed. Orders executed as: {your_orders_str}. (Failed to generate detailed analysis)" fallback_diary = (
f"Phase {game.current_short_phase} completed. Orders executed as: {your_orders_str}. (Failed to generate detailed analysis)"
)
self.add_diary_entry(fallback_diary, game.current_short_phase) self.add_diary_entry(fallback_diary, game.current_short_phase)
logger.warning(f"[{self.power_name}] Empty response from LLM. Added fallback phase result diary.") logger.warning(f"[{self.power_name}] Empty response from LLM. Added fallback phase result diary.")
success_status = "FALSE" success_status = "FALSE"
@ -882,30 +902,32 @@ class DiplomacyAgent:
model_name=self.client.model_name, model_name=self.client.model_name,
power_name=self.power_name, power_name=self.power_name,
phase=game.current_short_phase, phase=game.current_short_phase,
response_type='phase_result_diary', response_type="phase_result_diary",
raw_input_prompt=prompt, raw_input_prompt=prompt,
raw_response=raw_response, raw_response=raw_response,
success=success_status success=success_status,
) )
def log_state(self, prefix=""): def log_state(self, prefix=""):
logger.debug(f"[{self.power_name}] {prefix} State: Goals={self.goals}, Relationships={self.relationships}") logger.debug(f"[{self.power_name}] {prefix} State: Goals={self.goals}, Relationships={self.relationships}")
# Make this method async # Make this method async
async def analyze_phase_and_update_state(self, game: 'Game', board_state: dict, phase_summary: str, game_history: 'GameHistory', log_file_path: str): async def analyze_phase_and_update_state(
self, game: "Game", board_state: dict, phase_summary: str, game_history: "GameHistory", log_file_path: str
):
"""Analyzes the outcome of the last phase and updates goals/relationships using the LLM.""" """Analyzes the outcome of the last phase and updates goals/relationships using the LLM."""
# Use self.power_name internally # Use self.power_name internally
power_name = self.power_name power_name = self.power_name
current_phase = game.get_current_phase() # Get phase for logging current_phase = game.get_current_phase() # Get phase for logging
logger.info(f"[{power_name}] Analyzing phase {current_phase} outcome to update state...") logger.info(f"[{power_name}] Analyzing phase {current_phase} outcome to update state...")
self.log_state(f"Before State Update ({current_phase})") self.log_state(f"Before State Update ({current_phase})")
try: try:
# 1. Construct the prompt using the unformatted state update prompt file # 1. Construct the prompt using the unformatted state update prompt file
prompt_template = _load_prompt_file(get_prompt_path('state_update_prompt.txt'), prompts_dir=self.prompts_dir) prompt_template = _load_prompt_file(get_prompt_path("state_update_prompt.txt"), prompts_dir=self.prompts_dir)
if not prompt_template: if not prompt_template:
logger.error(f"[{power_name}] Could not load {get_prompt_path('state_update_prompt.txt')}. Skipping state update.") logger.error(f"[{power_name}] Could not load {get_prompt_path('state_update_prompt.txt')}. Skipping state update.")
return return
# Get previous phase safely from history # Get previous phase safely from history
if not game_history or not game_history.phases: if not game_history or not game_history.phases:
@ -913,7 +935,7 @@ class DiplomacyAgent:
return return
last_phase = game_history.phases[-1] last_phase = game_history.phases[-1]
last_phase_name = last_phase.name # Assuming phase object has a 'name' attribute last_phase_name = last_phase.name # Assuming phase object has a 'name' attribute
# Use the provided phase_summary parameter instead of retrieving it # Use the provided phase_summary parameter instead of retrieving it
last_phase_summary = phase_summary last_phase_summary = phase_summary
@ -929,13 +951,13 @@ class DiplomacyAgent:
context = build_context_prompt( context = build_context_prompt(
game=game, game=game,
board_state=board_state, # Use provided board_state parameter board_state=board_state, # Use provided board_state parameter
power_name=power_name, power_name=power_name,
possible_orders=possible_orders, # Pass possible_orders possible_orders=possible_orders, # Pass possible_orders
game_history=game_history, # Pass game_history game_history=game_history, # Pass game_history
agent_goals=self.goals, agent_goals=self.goals,
agent_relationships=self.relationships, agent_relationships=self.relationships,
agent_private_diary=formatted_diary, # Pass formatted diary agent_private_diary=formatted_diary, # Pass formatted diary
prompts_dir=self.prompts_dir, prompts_dir=self.prompts_dir,
) )
@ -944,10 +966,10 @@ class DiplomacyAgent:
# Create a readable board state string from the board_state dict # Create a readable board state string from the board_state dict
board_state_str = f"Board State:\n" board_state_str = f"Board State:\n"
for p_name, power_data in board_state.get('powers', {}).items(): for p_name, power_data in board_state.get("powers", {}).items():
# Get units and centers from the board state # Get units and centers from the board state
units = power_data.get('units', []) units = power_data.get("units", [])
centers = power_data.get('centers', []) centers = power_data.get("centers", [])
board_state_str += f" {p_name}: Units={units}, Centers={centers}\n" board_state_str += f" {p_name}: Units={units}, Centers={centers}\n"
# Extract year from the phase name (e.g., "S1901M" -> "1901") # Extract year from the phase name (e.g., "S1901M" -> "1901")
@ -956,12 +978,12 @@ class DiplomacyAgent:
prompt = prompt_template.format( prompt = prompt_template.format(
power_name=power_name, power_name=power_name,
current_year=current_year, current_year=current_year,
current_phase=last_phase_name, # Analyze the phase that just ended current_phase=last_phase_name, # Analyze the phase that just ended
board_state_str=board_state_str, board_state_str=board_state_str,
phase_summary=last_phase_summary, # Use provided phase_summary phase_summary=last_phase_summary, # Use provided phase_summary
other_powers=str(other_powers), # Pass as string representation other_powers=str(other_powers), # Pass as string representation
current_goals="\n".join([f"- {g}" for g in self.goals]) if self.goals else "None", current_goals="\n".join([f"- {g}" for g in self.goals]) if self.goals else "None",
current_relationships=str(self.relationships) if self.relationships else "None" current_relationships=str(self.relationships) if self.relationships else "None",
) )
logger.debug(f"[{power_name}] State update prompt:\n{prompt}") logger.debug(f"[{power_name}] State update prompt:\n{prompt}")
@ -973,25 +995,21 @@ class DiplomacyAgent:
log_file_path=log_file_path, log_file_path=log_file_path,
power_name=power_name, power_name=power_name,
phase=current_phase, phase=current_phase,
response_type='state_update', response_type="state_update",
) )
logger.debug(f"[{power_name}] Raw LLM response for state update: {response}") logger.debug(f"[{power_name}] Raw LLM response for state update: {response}")
log_entry_response_type = 'state_update' # Default for log_llm_response log_entry_response_type = "state_update" # Default for log_llm_response
log_entry_success = "FALSE" # Default log_entry_success = "FALSE" # Default
update_data = None # Initialize update_data = None # Initialize
if response is not None and response.strip(): # Check if response is not None and not just whitespace if response is not None and response.strip(): # Check if response is not None and not just whitespace
try: try:
# Conditionally format the response based on USE_UNFORMATTED_PROMPTS # Conditionally format the response based on USE_UNFORMATTED_PROMPTS
if os.getenv("USE_UNFORMATTED_PROMPTS") == "1": if os.getenv("USE_UNFORMATTED_PROMPTS") == "1":
# Format the natural language response into JSON # Format the natural language response into JSON
formatted_response = await format_with_gemini_flash( formatted_response = await format_with_gemini_flash(
response, response, FORMAT_STATE_UPDATE, power_name=power_name, phase=current_phase, log_file_path=log_file_path
FORMAT_STATE_UPDATE,
power_name=power_name,
phase=current_phase,
log_file_path=log_file_path
) )
else: else:
# Use the raw response directly (already formatted) # Use the raw response directly (already formatted)
@ -1007,29 +1025,31 @@ class DiplomacyAgent:
# Check if essential data ('updated_goals' or 'goals') is present AND is a list (for goals) # Check if essential data ('updated_goals' or 'goals') is present AND is a list (for goals)
# For relationships, check for 'updated_relationships' or 'relationships' AND is a dict. # For relationships, check for 'updated_relationships' or 'relationships' AND is a dict.
# Consider it TRUE if at least one of the primary data structures (goals or relationships) is present and correctly typed. # Consider it TRUE if at least one of the primary data structures (goals or relationships) is present and correctly typed.
goals_present_and_valid = isinstance(update_data.get('updated_goals'), list) or isinstance(update_data.get('goals'), list) goals_present_and_valid = isinstance(update_data.get("updated_goals"), list) or isinstance(update_data.get("goals"), list)
rels_present_and_valid = isinstance(update_data.get('updated_relationships'), dict) or isinstance(update_data.get('relationships'), dict) rels_present_and_valid = isinstance(update_data.get("updated_relationships"), dict) or isinstance(
update_data.get("relationships"), dict
)
if update_data and (goals_present_and_valid or rels_present_and_valid): if update_data and (goals_present_and_valid or rels_present_and_valid):
log_entry_success = "TRUE" log_entry_success = "TRUE"
elif update_data: # Parsed, but maybe not all essential data there or not correctly typed elif update_data: # Parsed, but maybe not all essential data there or not correctly typed
log_entry_success = "PARTIAL" log_entry_success = "PARTIAL"
log_entry_response_type = 'state_update_partial_data' log_entry_response_type = "state_update_partial_data"
else: # Parsed to None or empty dict/list, or data not in expected format else: # Parsed to None or empty dict/list, or data not in expected format
log_entry_success = "FALSE" log_entry_success = "FALSE"
log_entry_response_type = 'state_update_parsing_empty_or_invalid_data' log_entry_response_type = "state_update_parsing_empty_or_invalid_data"
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.error(f"[{power_name}] Failed to parse JSON response for state update: {e}. Raw response: {response}") logger.error(f"[{power_name}] Failed to parse JSON response for state update: {e}. Raw response: {response}")
log_entry_response_type = 'state_update_json_error' log_entry_response_type = "state_update_json_error"
# log_entry_success remains "FALSE" # log_entry_success remains "FALSE"
except Exception as e: except Exception as e:
logger.error(f"[{power_name}] Unexpected error parsing state update: {e}") logger.error(f"[{power_name}] Unexpected error parsing state update: {e}")
log_entry_response_type = 'state_update_unexpected_error' log_entry_response_type = "state_update_unexpected_error"
update_data = {} update_data = {}
# log_entry_success remains "FALSE" # log_entry_success remains "FALSE"
else: # response was None or empty/whitespace else: # response was None or empty/whitespace
logger.error(f"[{power_name}] No valid response (None or empty) received from LLM for state update.") logger.error(f"[{power_name}] No valid response (None or empty) received from LLM for state update.")
log_entry_response_type = 'state_update_no_response' log_entry_response_type = "state_update_no_response"
# log_entry_success remains "FALSE" # log_entry_success remains "FALSE"
# Log the attempt and its outcome # Log the attempt and its outcome
@ -1039,31 +1059,38 @@ class DiplomacyAgent:
power_name=power_name, power_name=power_name,
phase=current_phase, phase=current_phase,
response_type=log_entry_response_type, response_type=log_entry_response_type,
raw_input_prompt=prompt, # ENSURED raw_input_prompt=prompt, # ENSURED
raw_response=response if response is not None else "", # Handle if response is None raw_response=response if response is not None else "", # Handle if response is None
success=log_entry_success success=log_entry_success,
) )
# Fallback logic if update_data is still None or not usable # Fallback logic if update_data is still None or not usable
if not update_data or not (isinstance(update_data.get('updated_goals'), list) or isinstance(update_data.get('goals'), list) or isinstance(update_data.get('updated_relationships'), dict) or isinstance(update_data.get('relationships'), dict)): if not update_data or not (
logger.warning(f"[{power_name}] update_data is None or missing essential valid structures after LLM call. Using existing goals and relationships as fallback.") isinstance(update_data.get("updated_goals"), list)
update_data = { or isinstance(update_data.get("goals"), list)
or isinstance(update_data.get("updated_relationships"), dict)
or isinstance(update_data.get("relationships"), dict)
):
logger.warning(
f"[{power_name}] update_data is None or missing essential valid structures after LLM call. Using existing goals and relationships as fallback."
)
update_data = {
"updated_goals": self.goals, "updated_goals": self.goals,
"updated_relationships": self.relationships, "updated_relationships": self.relationships,
} }
logger.warning(f"[{power_name}] Using existing goals and relationships as fallback: {update_data}") logger.warning(f"[{power_name}] Using existing goals and relationships as fallback: {update_data}")
# Check for both possible key names (prompt uses "goals"/"relationships", # Check for both possible key names (prompt uses "goals"/"relationships",
# but code was expecting "updated_goals"/"updated_relationships") # but code was expecting "updated_goals"/"updated_relationships")
updated_goals = update_data.get('updated_goals') updated_goals = update_data.get("updated_goals")
if updated_goals is None: if updated_goals is None:
updated_goals = update_data.get('goals') updated_goals = update_data.get("goals")
if updated_goals is not None: if updated_goals is not None:
logger.debug(f"[{power_name}] Using 'goals' key instead of 'updated_goals'") logger.debug(f"[{power_name}] Using 'goals' key instead of 'updated_goals'")
updated_relationships = update_data.get('updated_relationships') updated_relationships = update_data.get("updated_relationships")
if updated_relationships is None: if updated_relationships is None:
updated_relationships = update_data.get('relationships') updated_relationships = update_data.get("relationships")
if updated_relationships is not None: if updated_relationships is not None:
logger.debug(f"[{power_name}] Using 'relationships' key instead of 'updated_relationships'") logger.debug(f"[{power_name}] Using 'relationships' key instead of 'updated_relationships'")
@ -1104,12 +1131,14 @@ class DiplomacyAgent:
# Update relationships if the dictionary is not empty after validation # Update relationships if the dictionary is not empty after validation
if valid_new_relationships: if valid_new_relationships:
self.relationships.update(valid_new_relationships) self.relationships.update(valid_new_relationships)
self.add_journal_entry(f"[{game.current_short_phase}] Relationships updated based on {last_phase_name}: {valid_new_relationships}") self.add_journal_entry(
elif updated_relationships: # Log if the original dict wasn't empty but validation removed everything f"[{game.current_short_phase}] Relationships updated based on {last_phase_name}: {valid_new_relationships}"
)
elif updated_relationships: # Log if the original dict wasn't empty but validation removed everything
logger.warning(f"[{power_name}] Found relationships in LLM response but none were valid after normalization. Using defaults.") logger.warning(f"[{power_name}] Found relationships in LLM response but none were valid after normalization. Using defaults.")
else: # Log if the original dict was empty else: # Log if the original dict was empty
logger.warning(f"[{power_name}] LLM did not provide valid 'updated_relationships' dict in state update.") logger.warning(f"[{power_name}] LLM did not provide valid 'updated_relationships' dict in state update.")
# Keep current relationships, no update needed # Keep current relationships, no update needed
except FileNotFoundError: except FileNotFoundError:
logger.error(f"[{power_name}] state_update_prompt.txt not found. Skipping state update.") logger.error(f"[{power_name}] state_update_prompt.txt not found. Skipping state update.")
@ -1128,11 +1157,11 @@ class DiplomacyAgent:
def update_relationship(self, other_power: str, status: str): def update_relationship(self, other_power: str, status: str):
"""Updates the agent's perceived relationship with another power.""" """Updates the agent's perceived relationship with another power."""
if other_power != self.power_name: if other_power != self.power_name:
self.relationships[other_power] = status self.relationships[other_power] = status
self.add_journal_entry(f"Relationship with {other_power} updated to {status}.") self.add_journal_entry(f"Relationship with {other_power} updated to {status}.")
logger.info(f"[{self.power_name}] Relationship with {other_power} set to {status}.") logger.info(f"[{self.power_name}] Relationship with {other_power} set to {status}.")
else: else:
logger.warning(f"[{self.power_name}] Attempted to set relationship with self.") logger.warning(f"[{self.power_name}] Attempted to set relationship with self.")
def get_agent_state_summary(self) -> str: def get_agent_state_summary(self) -> str:
"""Returns a string summary of the agent's current state.""" """Returns a string summary of the agent's current state."""
@ -1145,7 +1174,7 @@ class DiplomacyAgent:
# summary += f"\n Last Journal Entry: {self.private_journal[-1]}" # summary += f"\n Last Journal Entry: {self.private_journal[-1]}"
return summary return summary
def generate_plan(self, game: Game, board_state: dict, game_history: 'GameHistory') -> str: def generate_plan(self, game: Game, board_state: dict, game_history: "GameHistory") -> str:
"""Generates a strategic plan using the client and logs it.""" """Generates a strategic plan using the client and logs it."""
logger.info(f"Agent {self.power_name} generating strategic plan...") logger.info(f"Agent {self.power_name} generating strategic plan...")
try: try:
@ -1157,3 +1186,4 @@ class DiplomacyAgent:
logger.error(f"Agent {self.power_name} failed to generate plan: {e}") logger.error(f"Agent {self.power_name} failed to generate plan: {e}")
self.add_journal_entry(f"Failed to generate plan for phase {game.current_phase} due to error: {e}") self.add_journal_entry(f"Failed to generate plan for phase {game.current_phase} due to error: {e}")
return "Error: Failed to generate plan." return "Error: Failed to generate plan."

View file

@ -29,6 +29,7 @@ from collections import defaultdict, Counter
import re import re
from typing import Dict, List, Tuple, Optional, Any from typing import Dict, List, Tuple, Optional, Any
import statistics import statistics
from ..models import PowerEnum
class StatisticalGameAnalyzer: class StatisticalGameAnalyzer:
"""Production-ready analyzer for AI Diplomacy game statistics. """Production-ready analyzer for AI Diplomacy game statistics.
@ -47,7 +48,6 @@ class StatisticalGameAnalyzer:
'Ally': 2 'Ally': 2
} }
DIPLOMACY_POWERS = ['AUSTRIA', 'ENGLAND', 'FRANCE', 'GERMANY', 'ITALY', 'RUSSIA', 'TURKEY']
# Complete list of response types found in actual data # Complete list of response types found in actual data
RESPONSE_TYPES = [ RESPONSE_TYPES = [
@ -59,44 +59,8 @@ class StatisticalGameAnalyzer:
def __init__(self): def __init__(self):
"""Initialize analyzer with configuration constants.""" """Initialize analyzer with configuration constants."""
self.relationship_values = self.RELATIONSHIP_VALUES self.relationship_values = self.RELATIONSHIP_VALUES
self.powers = self.DIPLOMACY_POWERS
def _normalize_recipient_name(self, recipient: str) -> str:
"""Normalize recipient names to handle LLM typos and abbreviations."""
if not recipient:
return recipient
recipient = recipient.upper().strip()
# Handle common LLM typos and abbreviations found in data
name_mapping = {
'EGMANY': 'GERMANY',
'GERMAN': 'GERMANY',
'UK': 'ENGLAND',
'BRIT': 'ENGLAND',
'ENGLAND': 'ENGLAND', # Keep as-is
'FRANCE': 'FRANCE', # Keep as-is
'GERMANY': 'GERMANY', # Keep as-is
'ITALY': 'ITALY', # Keep as-is
'AUSTRIA': 'AUSTRIA', # Keep as-is
'RUSSIA': 'RUSSIA', # Keep as-is
'TURKEY': 'TURKEY', # Keep as-is
'Germany': 'GERMANY',
'England': 'ENGLAND',
'France': 'FRANCE',
'Italy': 'ITALY',
'Russia': 'RUSSIA',
'Austria': 'AUSTRIA',
'Turkey': 'TURKEY',
}
normalized = name_mapping.get(recipient, recipient)
# Validate it's a known power
if normalized not in self.DIPLOMACY_POWERS:
return None # Invalid recipient
return normalized
def analyze_folder(self, folder_path: str, output_dir: str = None) -> Tuple[str, str]: def analyze_folder(self, folder_path: str, output_dir: str = None) -> Tuple[str, str]:
""" """
@ -272,11 +236,11 @@ class StatisticalGameAnalyzer:
if not phase_data: if not phase_data:
continue continue
for power in self.powers: for power in PowerEnum:
for response_type in response_types: for response_type in response_types:
# Extract features for this specific power/phase/response_type combination # Extract features for this specific power/phase/response_type combination
features = self._extract_power_phase_response_features( features = self._extract_power_phase_response_features(
power, phase_name, response_type, llm_responses, phase_data, game_data power.value, phase_name, response_type, llm_responses, phase_data, game_data
) )
if features: if features:
phase_features.append(features) phase_features.append(features)
@ -530,7 +494,7 @@ class StatisticalGameAnalyzer:
game_features = [] game_features = []
game_scores = self._compute_game_scores(game_data) game_scores = self._compute_game_scores(game_data)
for power in self.powers: for power in PowerEnum:
features = { features = {
# === IDENTIFIERS === # === IDENTIFIERS ===
'game_id': game_data.get('id', 'unknown'), 'game_id': game_data.get('id', 'unknown'),
@ -718,7 +682,8 @@ class StatisticalGameAnalyzer:
# Categorize by relationship # Categorize by relationship
recipient = msg.get('recipient_power') recipient = msg.get('recipient_power')
normalized_recipient = self._normalize_recipient_name(recipient) # This will coerce some known aliases to match the 7 acceptable names (see models.py)
normalized_recipient = PowerEnum(recipient)
# Skip self-messages and invalid recipients # Skip self-messages and invalid recipients
if normalized_recipient and normalized_recipient != power and normalized_recipient in relationships: if normalized_recipient and normalized_recipient != power and normalized_recipient in relationships:
@ -969,7 +934,7 @@ class StatisticalGameAnalyzer:
break break
# elimination turn for every power # elimination turn for every power
elim_turn: dict[str, int | None] = {p: None for p in self.DIPLOMACY_POWERS} elim_turn: dict[str, int | None] = {p: None for p in [power.value for power in PowerEnum]}
for idx, ph in enumerate(phases): for idx, ph in enumerate(phases):
yr = self._phase_year(phases, idx) yr = self._phase_year(phases, idx)
if yr is None: if yr is None:

48
bot_client/config.py Normal file
View file

@ -0,0 +1,48 @@
import datetime
from pydantic_settings import BaseSettings
from pathlib import Path
import warnings
class Configuration(BaseSettings):
DEBUG: bool = False
log_file_path: Path | None = None
DEEPSEEK_API_KEY: str | None = None
OPENAI_API_KEY: str | None = None
ANTHROPIC_API_KEY: str | None = None
GEMINI_API_KEY: str | None = None
OPENROUTER_API_KEY: str | None = None
def __init__(self, power_name, **kwargs):
super().__init__(**kwargs)
self.log_file_path = Path(f"./logs/{datetime.datetime.now().strftime('%d-%m-%y_%H:%M')}/{power_name}.txt")
# Make the path absolute, gets rid of weirdness of calling this in different places
self.log_file_path = self.log_file_path.resolve()
self.log_file_path.parent.mkdir(parents=True, exist_ok=True)
self.log_file_path.touch(exist_ok=True)
self._validate_api_keys()
def _validate_api_keys(self):
"""Validate API keys at startup and issue warnings for missing keys"""
api_keys = [
"DEEPSEEK_API_KEY",
"OPENAI_API_KEY",
"ANTHROPIC_API_KEY",
"GEMINI_API_KEY",
"OPENROUTER_API_KEY",
]
for key in api_keys:
value = super().__getattribute__(key)
if not value or (isinstance(value, str) and len(value) == 0):
warnings.warn(f"API key '{key}' is not set or is empty", UserWarning)
def __getattribute__(self, name):
"""Override to check for empty API keys at access time"""
value = super().__getattribute__(name)
if name.endswith("_KEY") and (not value or (isinstance(value, str) and len(value) == 0)):
raise ValueError(f"API key '{name}' is not set or is empty. Please configure it before use.")
return value

60
config.py Normal file
View file

@ -0,0 +1,60 @@
import datetime
from typing import Optional
from pydantic_settings import BaseSettings
from pathlib import Path
import warnings
from models import PowerEnum
class Configuration(BaseSettings):
DEBUG: bool = False
log_file_path: Path | None = None
USE_UNFORMATTED_PROMPTS: bool = False
# API Keys to be validated. Warns if they aren't present at startup, raises ValueError if you attempt to use them when they aren't present.
DEEPSEEK_API_KEY: str | None = None
OPENAI_API_KEY: str | None = None
ANTHROPIC_API_KEY: str | None = None
GEMINI_API_KEY: str | None = None
OPENROUTER_API_KEY: str | None = None
def __init__(self, power_name: Optional[PowerEnum] = None, **kwargs):
super().__init__(**kwargs)
# Add a '-POWER' to the end of the file name if it's for a specific power
log_power_path = "-" + power_name if power_name else None
self.log_file_path = Path(f"./logs/{datetime.datetime.now().strftime('%d-%m-%y_%H:%M')}/logs{log_power_path} .txt")
# Make the path absolute, gets rid of weirdness of calling this in different places
self.log_file_path = self.log_file_path.resolve()
self.log_file_path.parent.mkdir(parents=True, exist_ok=True)
self.log_file_path.touch(exist_ok=True)
self._validate_api_keys()
def _validate_api_keys(self):
"""Validate API keys at startup and issue warnings for missing keys"""
api_keys = [
"DEEPSEEK_API_KEY",
"OPENAI_API_KEY",
"ANTHROPIC_API_KEY",
"GEMINI_API_KEY",
"OPENROUTER_API_KEY",
]
for key in api_keys:
value = super().__getattribute__(key)
if not value or (isinstance(value, str) and len(value) == 0):
warnings.warn(f"API key '{key}' is not set or is empty", UserWarning)
def __getattribute__(self, name):
"""Override to check for empty API keys at access time"""
value = super().__getattribute__(name)
# If this is a _KEY, it must be not None, string, and length > 0 to return. We do not validate the correctness of the key.
# e.g. "thisIsAKey" is valid in this sense.
if name.endswith("_KEY") and (not value or (isinstance(value, str) and len(value) == 0)):
raise ValueError(f"API key '{name}' is not set or is empty. Please configure it before use.")
return value
config = Configuration()

38
models.py Normal file
View file

@ -0,0 +1,38 @@
from enum import Enum
from typing import Any, Optional
# your “typo → canonical” map
_POWER_ALIASES = {
"EGMANY": "GERMANY",
"GERMAN": "GERMANY",
"UK": "ENGLAND",
"BRIT": "ENGLAND",
}
class PowerEnum(str, Enum):
AUSTRIA = "AUSTRIA"
ENGLAND = "ENGLAND"
FRANCE = "FRANCE"
GERMANY = "GERMANY"
ITALY = "ITALY"
RUSSIA = "RUSSIA"
TURKEY = "TURKEY"
@classmethod
def _missing_(cls, value: Any) -> Optional["Enum"]:
"""
Called when you do PowerEnum(value) and `value` isn't one of the raw enum values.
Here we normalize strings to upperstripped, apply aliases, then retry.
"""
if isinstance(value, str):
normalized = value.upper().strip()
# apply any synonyms/typos
normalized = _POWER_ALIASES.get(normalized, normalized)
# look up in the normal value→member map
member = cls._value2member_map_.get(normalized)
if member is not None:
return member
# by default, let Enum raise the ValueError
return super()._missing_(value)

View file

@ -5,28 +5,32 @@ description = "Add your description here"
readme = "README.md" readme = "README.md"
requires-python = ">=3.13" requires-python = ">=3.13"
dependencies = [ dependencies = [
"anthropic>=0.54.0", "anthropic>=0.54.0",
"bcrypt>=4.3.0", "bcrypt>=4.3.0",
"coloredlogs>=15.0.1", "coloredlogs>=15.0.1",
"dotenv>=0.9.9", "dotenv>=0.9.9",
"google-genai>=1.21.1", "google-genai>=1.21.1",
"google-generativeai>=0.8.5", "google-generativeai>=0.8.5",
"json-repair>=0.47.2", "json-repair>=0.47.2",
"json5>=0.12.0", "json5>=0.12.0",
"matplotlib>=3.10.3", "matplotlib>=3.10.3",
"openai>=1.90.0", "openai>=1.90.0",
"pylint>=2.3.0", "pydantic-settings>=2.10.1",
"pytest>=4.4.0", "pylint>=2.3.0",
"pytest-xdist>=3.7.0", "pytest>=4.4.0",
"python-dateutil>=2.9.0.post0", "pytest-xdist>=3.7.0",
"pytz>=2025.2", "python-dateutil>=2.9.0.post0",
"scipy>=1.16.0", "pytz>=2025.2",
"seaborn>=0.13.2", "scipy>=1.16.0",
"sphinx>=8.2.3", "seaborn>=0.13.2",
"sphinx-copybutton>=0.5.2", "sphinx>=8.2.3",
"sphinx-rtd-theme>=3.0.2", "sphinx-copybutton>=0.5.2",
"together>=1.5.17", "sphinx-rtd-theme>=3.0.2",
"tornado>=5.0", "together>=1.5.17",
"tqdm>=4.67.1", "tornado>=5.0",
"ujson>=5.10.0", "tqdm>=4.67.1",
"ujson>=5.10.0",
] ]
[tool.ruff]
line-length = 150

21
tests/test_models.py Normal file
View file

@ -0,0 +1,21 @@
from models import PowerEnum
def test_power_name_aliases():
# Test all aliases defined in _POWER_ALIASES
assert PowerEnum("UK") == PowerEnum.ENGLAND
assert PowerEnum("BRIT") == PowerEnum.ENGLAND
assert PowerEnum("EGMANY") == PowerEnum.GERMANY
assert PowerEnum("GERMAN") == PowerEnum.GERMANY
# Test direct enum values (no alias needed)
assert PowerEnum("AUSTRIA") == PowerEnum.AUSTRIA
assert PowerEnum("FRANCE") == PowerEnum.FRANCE
# Test case insensitivity
assert PowerEnum("france") == PowerEnum.FRANCE
assert PowerEnum("iTaLy") == PowerEnum.ITALY
# Test with whitespace
assert PowerEnum(" RUSSIA ") == PowerEnum.RUSSIA
assert PowerEnum("TURKEY ") == PowerEnum.TURKEY

1373
uv.lock generated

File diff suppressed because it is too large Load diff