diff --git a/ai_diplomacy/agent.py b/ai_diplomacy/agent.py index 3110454..f165d89 100644 --- a/ai_diplomacy/agent.py +++ b/ai_diplomacy/agent.py @@ -10,7 +10,7 @@ import ast from ..config import config # Assuming BaseModelClient is importable from clients.py in the same directory -from .clients import BaseModelClient, load_model_client +from .clients import BaseModelClient # Import load_prompt and the new logging wrapper from utils from .utils import load_prompt, run_llm_and_log, log_llm_response, get_prompt_path @@ -336,7 +336,7 @@ class DiplomacyAgent: ) # If even the last resort doesn't give a dict, return empty dict return {} - except Exception as e: + except Exception: logger.error(f"[{self.power_name}] All JSON extraction attempts failed. Original text: {original_text[:500]}...") return {} @@ -965,7 +965,7 @@ class DiplomacyAgent: other_powers = [p for p in game.powers if p != power_name] # Create a readable board state string from the board_state dict - board_state_str = f"Board State:\n" + board_state_str = "Board State:\n" for p_name, power_data in board_state.get("powers", {}).items(): # Get units and centers from the board state units = power_data.get("units", []) diff --git a/ai_diplomacy/clients.py b/ai_diplomacy/clients.py index fc44a73..07e5a43 100644 --- a/ai_diplomacy/clients.py +++ b/ai_diplomacy/clients.py @@ -1,41 +1,41 @@ import os import json -from json import JSONDecodeError import re import logging import ast # For literal_eval in JSON fallback parsing import aiohttp # For direct HTTP requests to Responses API -from typing import List, Dict, Optional, Any, Tuple, NamedTuple +from typing import List, Dict, Optional, Tuple, NamedTuple from dotenv import load_dotenv # Use Async versions of clients from openai import AsyncOpenAI -from openai import AsyncOpenAI as AsyncDeepSeekOpenAI # Alias for clarity +from openai import AsyncOpenAI as AsyncDeepSeekOpenAI # Alias for clarity from anthropic import AsyncAnthropic import asyncio import requests import google.generativeai as genai from together import AsyncTogether -from together.error import APIError as TogetherAPIError # For specific error handling +from together.error import APIError as TogetherAPIError # For specific error handling from ..config import config -from diplomacy.engine.message import GLOBAL from .game_history import GameHistory from .utils import load_prompt, run_llm_and_log, log_llm_response, generate_random_seed, get_prompt_path + # Import DiplomacyAgent for type hinting if needed, but avoid circular import if possible from .prompt_constructor import construct_order_generation_prompt, build_context_prompt # Moved formatter imports to avoid circular import - imported locally where needed # set logger back to just info logger = logging.getLogger("client") -logger.setLevel(logging.DEBUG) # Keep debug for now during async changes +logger.setLevel(logging.DEBUG) # Keep debug for now during async changes # Note: BasicConfig might conflict if already configured in lm_game. Keep client-specific for now. # logging.basicConfig(level=logging.DEBUG) # Might be redundant if lm_game configures root load_dotenv() + ############################################################################## # 1) Base Interface ############################################################################## @@ -52,7 +52,7 @@ class BaseModelClient: self.model_name = model_name self.prompts_dir = prompts_dir # Load a default initially, can be overwritten by set_system_prompt - self.system_prompt = load_prompt("system_prompt.txt", prompts_dir=self.prompts_dir) + self.system_prompt = load_prompt("system_prompt.txt", prompts_dir=self.prompts_dir) self.max_tokens = 16000 # default unless overridden def set_system_prompt(self, content: str): @@ -76,13 +76,13 @@ class BaseModelClient: board_state, power_name: str, possible_orders: Dict[str, List[str]], - conversation_text: str, # This is GameHistory + conversation_text: str, # This is GameHistory model_error_stats: dict, log_file_path: str, phase: str, agent_goals: Optional[List[str]] = None, agent_relationships: Optional[Dict[str, str]] = None, - agent_private_diary_str: Optional[str] = None, # Added + agent_private_diary_str: Optional[str] = None, # Added ) -> List[str]: """ 1) Builds the prompt with conversation context if available @@ -90,7 +90,7 @@ class BaseModelClient: 3) Parses JSON block """ # The 'conversation_text' parameter was GameHistory. Renaming for clarity. - game_history_obj = conversation_text + game_history_obj = conversation_text prompt = construct_order_generation_prompt( system_prompt=self.system_prompt, @@ -98,7 +98,7 @@ class BaseModelClient: board_state=board_state, power_name=power_name, possible_orders=possible_orders, - game_history=game_history_obj, # Pass GameHistory object + game_history=game_history_obj, # Pass GameHistory object agent_goals=agent_goals, agent_relationships=agent_relationships, agent_private_diary_str=agent_private_diary_str, @@ -108,7 +108,7 @@ class BaseModelClient: raw_response = "" # Initialize success status. Will be updated based on outcome. success_status = "Failure: Initialized" - parsed_orders_for_return = self.fallback_orders(possible_orders) # Default to fallback + parsed_orders_for_return = self.fallback_orders(possible_orders) # Default to fallback try: # Call LLM using the logging wrapper @@ -118,36 +118,29 @@ class BaseModelClient: log_file_path=log_file_path, power_name=power_name, phase=phase, - response_type='order', # Context for run_llm_and_log's own error logging - temperature=0 - ) - logger.debug( - f"[{self.model_name}] Raw LLM response for {power_name} orders:\n{raw_response}" + response_type="order", # Context for run_llm_and_log's own error logging + temperature=0, ) + logger.debug(f"[{self.model_name}] Raw LLM response for {power_name} orders:\n{raw_response}") # Conditionally format the response based on USE_UNFORMATTED_PROMPTS if config.USE_UNFORMATTED_PROMPTS: # Local import to avoid circular dependency from .formatter import format_with_gemini_flash, FORMAT_ORDERS + # Format the natural language response into structured format formatted_response = await format_with_gemini_flash( - raw_response, - FORMAT_ORDERS, - power_name=power_name, - phase=phase, - log_file_path=log_file_path + raw_response, FORMAT_ORDERS, power_name=power_name, phase=phase, log_file_path=log_file_path ) else: # Use the raw response directly (already formatted) formatted_response = raw_response - + # Attempt to parse the final "orders" from the formatted response move_list = self._extract_moves(formatted_response, power_name) if not move_list: - logger.warning( - f"[{self.model_name}] Could not extract moves for {power_name}. Using fallback." - ) + logger.warning(f"[{self.model_name}] Could not extract moves for {power_name}. Using fallback.") if model_error_stats is not None and self.model_name in model_error_stats: model_error_stats[self.model_name].setdefault("order_decoding_errors", 0) model_error_stats[self.model_name]["order_decoding_errors"] += 1 @@ -163,18 +156,18 @@ class BaseModelClient: max_invalid_to_log = 5 display_invalid_moves = invalid_moves_list[:max_invalid_to_log] omitted_count = len(invalid_moves_list) - len(display_invalid_moves) - + invalid_moves_str = ", ".join(display_invalid_moves) if omitted_count > 0: invalid_moves_str += f", ... ({omitted_count} more)" - + success_status = f"Failure: Invalid LLM Moves ({len(invalid_moves_list)}): {invalid_moves_str}" # If some moves were validated despite others being invalid, it's still not a full 'Success' # because the LLM didn't provide a fully usable set of orders without intervention/fallbacks. # The fallback_orders logic within _validate_orders might fill in missing pieces, # but the key is that the LLM *proposed* invalid moves. - if not validated_moves: # All LLM moves were invalid - logger.warning(f"[{power_name}] All LLM-proposed moves were invalid. Using fallbacks. Invalid: {invalid_moves_list}") + if not validated_moves: # All LLM moves were invalid + logger.warning(f"[{power_name}] All LLM-proposed moves were invalid. Using fallbacks. Invalid: {invalid_moves_list}") else: logger.info(f"[{power_name}] Some LLM-proposed moves were invalid. Using fallbacks/validated. Invalid: {invalid_moves_list}") else: @@ -186,16 +179,16 @@ class BaseModelClient: # Fallback is already set to parsed_orders_for_return finally: # Log the attempt regardless of outcome - if log_file_path: # Only log if a path is provided + if log_file_path: # Only log if a path is provided log_llm_response( log_file_path=log_file_path, model_name=self.model_name, power_name=power_name, phase=phase, - response_type="order_generation", # Specific type for CSV logging - raw_input_prompt=prompt, # Renamed from 'prompt' to match log_llm_response arg + response_type="order_generation", # Specific type for CSV logging + raw_input_prompt=prompt, # Renamed from 'prompt' to match log_llm_response arg raw_response=raw_response, - success=success_status + success=success_status, # token_usage and cost can be added later if available and if log_llm_response supports them ) return parsed_orders_for_return @@ -216,57 +209,43 @@ class BaseModelClient: if not matches: # Some LLMs might not put the colon or might have triple backtick fences. - logger.debug( - f"[{self.model_name}] Regex parse #1 failed for {power_name}. Trying alternative patterns." - ) + logger.debug(f"[{self.model_name}] Regex parse #1 failed for {power_name}. Trying alternative patterns.") # 1b) Check for inline JSON after "PARSABLE OUTPUT" pattern_alt = r"PARSABLE OUTPUT\s*\{(.*?)\}\s*$" matches = re.search(pattern_alt, raw_response, re.DOTALL) - + if not matches: # 1c) Check for **PARSABLE OUTPUT:** pattern (with asterisks) - logger.debug( - f"[{self.model_name}] Regex parse #2 failed for {power_name}. Trying asterisk-wrapped pattern." - ) + logger.debug(f"[{self.model_name}] Regex parse #2 failed for {power_name}. Trying asterisk-wrapped pattern.") pattern_asterisk = r"\*\*PARSABLE OUTPUT:\*\*\s*(\{[\s\S]*?\})" matches = re.search(pattern_asterisk, raw_response, re.DOTALL) if not matches: - logger.debug( - f"[{self.model_name}] Regex parse #3 failed for {power_name}. Trying triple-backtick code fences." - ) + logger.debug(f"[{self.model_name}] Regex parse #3 failed for {power_name}. Trying triple-backtick code fences.") # 2) If still no match, check for triple-backtick code fences containing JSON if not matches: code_fence_pattern = r"```json\n(.*?)\n```" matches = re.search(code_fence_pattern, raw_response, re.DOTALL) if matches: - logger.debug( - f"[{self.model_name}] Found triple-backtick JSON block for {power_name}." - ) - + logger.debug(f"[{self.model_name}] Found triple-backtick JSON block for {power_name}.") + # 2b) Also try plain ``` code fences without json marker if not matches: code_fence_plain = r"```\n(.*?)\n```" matches = re.search(code_fence_plain, raw_response, re.DOTALL) if matches: - logger.debug( - f"[{self.model_name}] Found plain triple-backtick block for {power_name}." - ) - + logger.debug(f"[{self.model_name}] Found plain triple-backtick block for {power_name}.") + # 2c) Try to find bare JSON object anywhere in the response if not matches: - logger.debug( - f"[{self.model_name}] No explicit markers found for {power_name}. Looking for bare JSON." - ) + logger.debug(f"[{self.model_name}] No explicit markers found for {power_name}. Looking for bare JSON.") # Look for a JSON object that contains "orders" key bare_json_pattern = r'(\{[^{}]*"orders"\s*:\s*\[[^\]]*\][^{}]*\})' matches = re.search(bare_json_pattern, raw_response, re.DOTALL) if matches: - logger.debug( - f"[{self.model_name}] Found bare JSON object with 'orders' key for {power_name}." - ) + logger.debug(f"[{self.model_name}] Found bare JSON object with 'orders' key for {power_name}.") # 3) Attempt to parse JSON if we found anything json_text = None @@ -283,9 +262,7 @@ class BaseModelClient: json_text = json_text.strip() if not json_text: - logger.debug( - f"[{self.model_name}] No JSON text found in LLM response for {power_name}." - ) + logger.debug(f"[{self.model_name}] No JSON text found in LLM response for {power_name}.") return None # 3a) Try JSON loading @@ -293,14 +270,12 @@ class BaseModelClient: data = json.loads(json_text) return data.get("orders", None) except json.JSONDecodeError as e: - logger.warning( - f"[{self.model_name}] JSON decode failed for {power_name}: {e}. Trying to fix common issues." - ) - + logger.warning(f"[{self.model_name}] JSON decode failed for {power_name}: {e}. Trying to fix common issues.") + # Try to fix common JSON issues try: # Remove trailing commas - fixed_json = re.sub(r',\s*([\}\]])', r'\1', json_text) + fixed_json = re.sub(r",\s*([\}\]])", r"\1", json_text) # Fix single quotes to double quotes fixed_json = fixed_json.replace("'", '"') # Try parsing again @@ -308,14 +283,12 @@ class BaseModelClient: logger.info(f"[{self.model_name}] Successfully parsed JSON after fixes for {power_name}") return data.get("orders", None) except json.JSONDecodeError: - logger.warning( - f"[{self.model_name}] JSON decode still failed after fixes for {power_name}. Trying to remove inline comments." - ) - + logger.warning(f"[{self.model_name}] JSON decode still failed after fixes for {power_name}. Trying to remove inline comments.") + # Try to remove inline comments (// style) try: # Remove // comments from each line - lines = json_text.split('\n') + lines = json_text.split("\n") cleaned_lines = [] for line in lines: # Find // that's not inside quotes @@ -326,33 +299,31 @@ class BaseModelClient: if escape_next: escape_next = False continue - if char == '\\': + if char == "\\": escape_next = True continue if char == '"' and not escape_next: in_quotes = not in_quotes - if not in_quotes and line[i:i+2] == '//': + if not in_quotes and line[i : i + 2] == "//": comment_pos = i break - + if comment_pos >= 0: # Remove comment but keep any trailing comma cleaned_line = line[:comment_pos].rstrip() else: cleaned_line = line cleaned_lines.append(cleaned_line) - - comment_free_json = '\n'.join(cleaned_lines) + + comment_free_json = "\n".join(cleaned_lines) # Also remove trailing commas after comment removal - comment_free_json = re.sub(r',\s*([\}\]])', r'\1', comment_free_json) - + comment_free_json = re.sub(r",\s*([\}\]])", r"\1", comment_free_json) + data = json.loads(comment_free_json) logger.info(f"[{self.model_name}] Successfully parsed JSON after removing inline comments for {power_name}") return data.get("orders", None) except json.JSONDecodeError: - logger.warning( - f"[{self.model_name}] JSON decode still failed after removing comments for {power_name}. Trying bracket fallback." - ) + logger.warning(f"[{self.model_name}] JSON decode still failed after removing comments for {power_name}. Trying bracket fallback.") # 3b) Attempt bracket fallback: we look for the substring after "orders" # E.g. "orders: ['A BUD H']" and parse it. This is risky but can help with minor JSON format errors. @@ -366,30 +337,26 @@ class BaseModelClient: if isinstance(moves, list): return moves except Exception as e2: - logger.warning( - f"[{self.model_name}] Bracket fallback parse also failed for {power_name}: {e2}" - ) + logger.warning(f"[{self.model_name}] Bracket fallback parse also failed for {power_name}: {e2}") # If all attempts failed return None - - def _validate_orders( - self, moves: List[str], possible_orders: Dict[str, List[str]] - ) -> Tuple[List[str], List[str]]: # MODIFIED RETURN TYPE + + def _validate_orders(self, moves: List[str], possible_orders: Dict[str, List[str]]) -> Tuple[List[str], List[str]]: # MODIFIED RETURN TYPE """ Filter out invalid moves, fill missing with HOLD, else fallback. Returns a tuple: (validated_moves, invalid_moves_found) """ logger.debug(f"[{self.model_name}] Proposed LLM moves: {moves}") validated = [] - invalid_moves_found = [] # ADDED: To collect invalid moves + invalid_moves_found = [] # ADDED: To collect invalid moves used_locs = set() if not isinstance(moves, list): logger.debug(f"[{self.model_name}] Moves not a list, fallback.") # Return fallback and empty list for invalid_moves_found as no specific LLM moves were processed - return self.fallback_orders(possible_orders), [] - + return self.fallback_orders(possible_orders), [] + for move_str in moves: # Check if it's in possible orders if any(move_str in loc_orders for loc_orders in possible_orders.values()): @@ -399,21 +366,21 @@ class BaseModelClient: used_locs.add(parts[1][:3]) else: logger.debug(f"[{self.model_name}] Invalid move from LLM: {move_str}") - invalid_moves_found.append(move_str) # ADDED: Collect invalid move + invalid_moves_found.append(move_str) # ADDED: Collect invalid move # Fill missing with hold for loc, orders_list in possible_orders.items(): if loc not in used_locs and orders_list: hold_candidates = [o for o in orders_list if o.endswith("H")] - validated.append( - hold_candidates[0] if hold_candidates else orders_list[0] - ) + validated.append(hold_candidates[0] if hold_candidates else orders_list[0]) - if not validated and not invalid_moves_found: # Only if LLM provided no valid moves and no invalid moves (e.g. empty list from LLM) + if not validated and not invalid_moves_found: # Only if LLM provided no valid moves and no invalid moves (e.g. empty list from LLM) logger.warning(f"[{self.model_name}] No valid LLM moves provided and no invalid ones to report. Using fallback.") return self.fallback_orders(possible_orders), [] - elif not validated and invalid_moves_found: # All LLM moves were invalid - logger.warning(f"[{self.model_name}] All LLM moves invalid ({len(invalid_moves_found)} found), using fallback. Invalid: {invalid_moves_found}") + elif not validated and invalid_moves_found: # All LLM moves were invalid + logger.warning( + f"[{self.model_name}] All LLM moves invalid ({len(invalid_moves_found)} found), using fallback. Invalid: {invalid_moves_found}" + ) # We return empty list for validated, but the invalid_moves_found list is populated return self.fallback_orders(possible_orders), invalid_moves_found @@ -442,9 +409,8 @@ class BaseModelClient: # log_file_path: str, # Not used directly by build_context_prompt agent_goals: Optional[List[str]] = None, agent_relationships: Optional[Dict[str, str]] = None, - agent_private_diary_str: Optional[str] = None, # Added + agent_private_diary_str: Optional[str] = None, # Added ) -> str: - instructions = load_prompt("planning_instructions.txt", prompts_dir=self.prompts_dir) context = self.build_context_prompt( @@ -455,7 +421,7 @@ class BaseModelClient: game_history, agent_goals=agent_goals, agent_relationships=agent_relationships, - agent_private_diary=agent_private_diary_str, # Pass diary string + agent_private_diary=agent_private_diary_str, # Pass diary string prompts_dir=self.prompts_dir, ) @@ -472,7 +438,7 @@ class BaseModelClient: # log_file_path: str, # Not used directly by build_context_prompt agent_goals: Optional[List[str]] = None, agent_relationships: Optional[Dict[str, str]] = None, - agent_private_diary_str: Optional[str] = None, # Added + agent_private_diary_str: Optional[str] = None, # Added ) -> str: # MINIMAL CHANGE: Just change to load unformatted version conditionally instructions = load_prompt(get_prompt_path("conversation_instructions.txt"), prompts_dir=self.prompts_dir) @@ -486,19 +452,19 @@ class BaseModelClient: game_history, agent_goals=agent_goals, agent_relationships=agent_relationships, - agent_private_diary=agent_private_diary_str, # Pass diary string + agent_private_diary=agent_private_diary_str, # Pass diary string prompts_dir=self.prompts_dir, ) - + # KEEP ORIGINAL: Get recent messages targeting this power to prioritize responses recent_messages_to_power = game_history.get_recent_messages_to_power(power_name, limit=3) - + # KEEP ORIGINAL: Debug logging to verify messages logger.info(f"[{power_name}] Found {len(recent_messages_to_power)} high priority messages to respond to") if recent_messages_to_power: for i, msg in enumerate(recent_messages_to_power): - logger.info(f"[{power_name}] Priority message {i+1}: From {msg['sender']} in {msg['phase']}: {msg['content'][:50]}...") - + logger.info(f"[{power_name}] Priority message {i + 1}: From {msg['sender']} in {msg['phase']}: {msg['content'][:50]}...") + # KEEP ORIGINAL: Add a section for unanswered messages unanswered_messages = "\n\nRECENT MESSAGES REQUIRING YOUR ATTENTION:\n" if recent_messages_to_power: @@ -506,25 +472,32 @@ class BaseModelClient: unanswered_messages += f"\nFrom {msg['sender']} in {msg['phase']}: {msg['content']}\n" else: unanswered_messages += "\nNo urgent messages requiring direct responses.\n" - + final_prompt = context + unanswered_messages + "\n\n" + instructions - final_prompt = final_prompt.replace('AUSTRIA', 'Austria').replace('ENGLAND', "England").replace('FRANCE', 'France').replace('GERMANY', 'Germany').replace('ITALY', "Italy").replace('RUSSIA', 'Russia').replace('TURKEY', 'Turkey') + final_prompt = ( + final_prompt.replace("AUSTRIA", "Austria") + .replace("ENGLAND", "England") + .replace("FRANCE", "France") + .replace("GERMANY", "Germany") + .replace("ITALY", "Italy") + .replace("RUSSIA", "Russia") + .replace("TURKEY", "Turkey") + ) return final_prompt - async def get_planning_reply( # Renamed from get_plan to avoid conflict with get_plan in agent.py + async def get_planning_reply( # Renamed from get_plan to avoid conflict with get_plan in agent.py self, game, board_state, power_name: str, possible_orders: Dict[str, List[str]], game_history: GameHistory, - game_phase: str, # Used for logging - log_file_path: str, # Used for logging + game_phase: str, # Used for logging + log_file_path: str, # Used for logging agent_goals: Optional[List[str]] = None, agent_relationships: Optional[Dict[str, str]] = None, - agent_private_diary_str: Optional[str] = None, # Added + agent_private_diary_str: Optional[str] = None, # Added ) -> str: - prompt = self.build_planning_prompt( game, board_state, @@ -535,7 +508,7 @@ class BaseModelClient: # log_file_path, # Not passed to build_planning_prompt directly agent_goals=agent_goals, agent_relationships=agent_relationships, - agent_private_diary_str=agent_private_diary_str, # Pass diary string + agent_private_diary_str=agent_private_diary_str, # Pass diary string ) # Call LLM using the logging wrapper @@ -544,12 +517,12 @@ class BaseModelClient: prompt=prompt, log_file_path=log_file_path, power_name=power_name, - phase=game_phase, # Use game_phase for logging - response_type='plan_reply', # Changed from 'plan' to avoid confusion + phase=game_phase, # Use game_phase for logging + response_type="plan_reply", # Changed from 'plan' to avoid confusion ) logger.debug(f"[{self.model_name}] Raw LLM response for {power_name} planning reply:\n{raw_response}") return raw_response - + async def get_conversation_reply( self, game, @@ -559,18 +532,18 @@ class BaseModelClient: game_history: GameHistory, game_phase: str, log_file_path: str, - active_powers: Optional[List[str]] = None, + active_powers: Optional[List[str]] = None, agent_goals: Optional[List[str]] = None, agent_relationships: Optional[Dict[str, str]] = None, - agent_private_diary_str: Optional[str] = None, + agent_private_diary_str: Optional[str] = None, ) -> List[Dict[str, str]]: """ Generates a negotiation message, considering agent state. """ - raw_input_prompt = "" # Initialize for finally block - raw_response = "" # Initialize for finally block - success_status = "Failure: Initialized" # Default status - messages_to_return = [] # Initialize to ensure it's defined + raw_input_prompt = "" # Initialize for finally block + raw_response = "" # Initialize for finally block + success_status = "Failure: Initialized" # Default status + messages_to_return = [] # Initialize to ensure it's defined try: raw_input_prompt = self.build_conversation_prompt( @@ -581,7 +554,7 @@ class BaseModelClient: game_history, agent_goals=agent_goals, agent_relationships=agent_relationships, - agent_private_diary_str=agent_private_diary_str, + agent_private_diary_str=agent_private_diary_str, ) logger.debug(f"[{self.model_name}] Conversation prompt for {power_name}:\n{raw_input_prompt}") @@ -591,33 +564,30 @@ class BaseModelClient: prompt=raw_input_prompt, log_file_path=log_file_path, power_name=power_name, - phase=game_phase, - response_type='negotiation', # For run_llm_and_log's internal context + phase=game_phase, + response_type="negotiation", # For run_llm_and_log's internal context ) logger.debug(f"[{self.model_name}] Raw LLM response for {power_name}:\n{raw_response}") - + # Conditionally format the response based on USE_UNFORMATTED_PROMPTS if config.USE_UNFORMATTED_PROMPTS: # Local import to avoid circular dependency from .formatter import format_with_gemini_flash, FORMAT_CONVERSATION + # Format the natural language response into structured JSON formatted_response = await format_with_gemini_flash( - raw_response, - FORMAT_CONVERSATION, - power_name=power_name, - phase=game_phase, - log_file_path=log_file_path + raw_response, FORMAT_CONVERSATION, power_name=power_name, phase=game_phase, log_file_path=log_file_path ) else: # Use the raw response directly (already formatted) formatted_response = raw_response - + parsed_messages = [] json_blocks = [] json_decode_error_occurred = False - + # For formatted response, we expect a clean JSON array - try: + try: data = json.loads(formatted_response) if isinstance(data, list): parsed_messages = data @@ -628,14 +598,14 @@ class BaseModelClient: logger.warning(f"[{self.model_name}] Failed to parse formatted response as JSON, falling back to regex") # Fall back to original parsing logic using formatted_response raw_response = formatted_response - + # Original parsing logic as fallback if not parsed_messages: # Attempt to find blocks enclosed in {{...}} - double_brace_blocks = re.findall(r'\{\{(.*?)\}\}', raw_response, re.DOTALL) + double_brace_blocks = re.findall(r"\{\{(.*?)\}\}", raw_response, re.DOTALL) if double_brace_blocks: # If {{...}} blocks are found, assume each is a self-contained JSON object - json_blocks.extend(['{' + block.strip() + '}' for block in double_brace_blocks]) + json_blocks.extend(["{" + block.strip() + "}" for block in double_brace_blocks]) else: # If no {{...}} blocks, look for ```json ... ``` markdown blocks code_block_match = re.search(r"```json\n(.*?)\n```", raw_response, re.DOTALL) @@ -650,10 +620,10 @@ class BaseModelClient: json_blocks = [json.dumps(data)] except json.JSONDecodeError: # If parsing the whole block fails, fall back to regex for individual objects - json_blocks = re.findall(r'\{.*?\}', potential_json_array_or_objects, re.DOTALL) + json_blocks = re.findall(r"\{.*?\}", potential_json_array_or_objects, re.DOTALL) else: # If no markdown block, fall back to regex for any JSON object in the response - json_blocks = re.findall(r'\{.*?\}', raw_response, re.DOTALL) + json_blocks = re.findall(r"\{.*?\}", raw_response, re.DOTALL) # Process json_blocks if we have them from fallback parsing if not parsed_messages and json_blocks: @@ -661,13 +631,13 @@ class BaseModelClient: try: cleaned_block = block.strip() # Attempt to fix common JSON issues like trailing commas before parsing - cleaned_block = re.sub(r',\s*([\}\]])', r'\1', cleaned_block) + cleaned_block = re.sub(r",\s*([\}\]])", r"\1", cleaned_block) parsed_message = json.loads(cleaned_block) parsed_messages.append(parsed_message) except json.JSONDecodeError as e: logger.warning(f"[{self.model_name}] Failed to parse JSON block {block_index} for {power_name}: {e}") json_decode_error_occurred = True - + if not parsed_messages: logger.warning(f"[{self.model_name}] No valid messages found in response for {power_name}") success_status = "Success: No messages found" @@ -684,7 +654,7 @@ class BaseModelClient: else: logger.warning(f"[{self.model_name}] Invalid message structure for {power_name}") parsed_messages = validated_messages - + # Set final status and return value if parsed_messages: success_status = "Success: Messages extracted" @@ -695,11 +665,11 @@ class BaseModelClient: logger.debug(f"[{self.model_name}] Validated conversation replies for {power_name}: {messages_to_return}") # return messages_to_return # Return will happen in finally block or after - + except Exception as e: logger.error(f"[{self.model_name}] Error in get_conversation_reply for {power_name}: {e}", exc_info=True) success_status = f"Failure: Exception ({type(e).__name__})" - messages_to_return = [] # Ensure empty list on general exception + messages_to_return = [] # Ensure empty list on general exception finally: if log_file_path: log_llm_response( @@ -710,11 +680,11 @@ class BaseModelClient: response_type="negotiation_message", raw_input_prompt=raw_input_prompt, raw_response=raw_response, - success=success_status + success=success_status, ) return messages_to_return - async def get_plan( # This is the original get_plan, now distinct from get_planning_reply + async def get_plan( # This is the original get_plan, now distinct from get_planning_reply self, game, board_state, @@ -724,14 +694,14 @@ class BaseModelClient: log_file_path: str, agent_goals: Optional[List[str]] = None, agent_relationships: Optional[Dict[str, str]] = None, - agent_private_diary_str: Optional[str] = None, # Added + agent_private_diary_str: Optional[str] = None, # Added ) -> str: """ Generates a strategic plan for the given power based on the current state. This method is called by the agent's generate_plan method. """ logger.info(f"Client generating strategic plan for {power_name}...") - + planning_instructions = load_prompt("planning_instructions.txt", prompts_dir=self.prompts_dir) if not planning_instructions: logger.error("Could not load planning_instructions.txt! Cannot generate plan.") @@ -740,17 +710,17 @@ class BaseModelClient: # For planning, possible_orders might be less critical for the context, # but build_context_prompt expects it. We can pass an empty dict or calculate it. # For simplicity, let's pass empty if not strictly needed by context for planning. - possible_orders_for_context = {} # game.get_all_possible_orders() if needed by context - + possible_orders_for_context = {} # game.get_all_possible_orders() if needed by context + context_prompt = self.build_context_prompt( game, board_state, power_name, - possible_orders_for_context, + possible_orders_for_context, game_history, agent_goals=agent_goals, agent_relationships=agent_relationships, - agent_private_diary=agent_private_diary_str, # Pass diary string + agent_private_diary=agent_private_diary_str, # Pass diary string prompts_dir=self.prompts_dir, ) @@ -765,12 +735,12 @@ class BaseModelClient: try: # Use run_llm_and_log for the actual LLM call raw_plan_response = await run_llm_and_log( - client=self, # Pass self (the client instance) + client=self, # Pass self (the client instance) prompt=full_prompt, log_file_path=log_file_path, power_name=power_name, - phase=game.current_short_phase, - response_type='plan_generation', # More specific type for run_llm_and_log context + phase=game.current_short_phase, + response_type="plan_generation", # More specific type for run_llm_and_log context ) logger.debug(f"[{self.model_name}] Raw LLM response for {power_name} plan generation:\n{raw_plan_response}") # No parsing needed for the plan, return the raw string @@ -781,16 +751,16 @@ class BaseModelClient: success_status = f"Failure: Exception ({type(e).__name__})" plan_to_return = f"Error: Failed to generate plan for {power_name} due to exception: {e}" finally: - if log_file_path: # Only log if a path is provided + if log_file_path: # Only log if a path is provided log_llm_response( log_file_path=log_file_path, model_name=self.model_name, power_name=power_name, phase=game.current_short_phase if game else "UnknownPhase", - response_type="plan_generation", # Specific type for CSV logging - raw_input_prompt=full_prompt, # Renamed from 'full_prompt' to match log_llm_response arg + response_type="plan_generation", # Specific type for CSV logging + raw_input_prompt=full_prompt, # Renamed from 'full_prompt' to match log_llm_response arg raw_response=raw_plan_response, - success=success_status + success=success_status, # token_usage and cost can be added later ) return plan_to_return @@ -813,11 +783,7 @@ class OpenAIClient(BaseModelClient): ): super().__init__(model_name, prompts_dir=prompts_dir) - self.base_url = ( - base_url - or os.environ.get("OPENAI_BASE_URL") - or "https://api.openai.com/v1" - ) + self.base_url = base_url or os.environ.get("OPENAI_BASE_URL") or "https://api.openai.com/v1" self.api_key = api_key or os.environ.get("OPENAI_API_KEY") if not self.api_key: @@ -832,18 +798,14 @@ class OpenAIClient(BaseModelClient): inject_random_seed: bool = True, ) -> str: try: - system_prompt_content = ( - f"{generate_random_seed()}\n\n{self.system_prompt}" - if inject_random_seed - else self.system_prompt - ) + system_prompt_content = f"{generate_random_seed()}\n\n{self.system_prompt}" if inject_random_seed else self.system_prompt prompt_with_cta = f"{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:" response = await self.client.chat.completions.create( model=self.model_name, messages=[ {"role": "system", "content": system_prompt_content}, - {"role": "user", "content": prompt_with_cta}, + {"role": "user", "content": prompt_with_cta}, ], temperature=temperature, max_tokens=self.max_tokens, @@ -888,20 +850,14 @@ class ClaudeClient(BaseModelClient): temperature=temperature, ) if not response.content: - logger.warning( - f"[{self.model_name}] Empty content in Claude generate_response. Returning empty." - ) + logger.warning(f"[{self.model_name}] Empty content in Claude generate_response. Returning empty.") return "" return response.content[0].text.strip() if response.content else "" except json.JSONDecodeError as json_err: - logger.error( - f"[{self.model_name}] JSON decoding failed in generate_response: {json_err}" - ) + logger.error(f"[{self.model_name}] JSON decoding failed in generate_response: {json_err}") return "" except Exception as e: - logger.error( - f"[{self.model_name}] Unexpected error in generate_response: {e}" - ) + logger.error(f"[{self.model_name}] Unexpected error in generate_response: {e}") return "" @@ -929,19 +885,14 @@ class GeminiClient(BaseModelClient): full_prompt = system_prompt_content + prompt + "\n\nPROVIDE YOUR RESPONSE BELOW:" try: - generation_config = genai.types.GenerationConfig( - temperature=temperature, - max_output_tokens=self.max_tokens - ) + generation_config = genai.types.GenerationConfig(temperature=temperature, max_output_tokens=self.max_tokens) response = await self.client.generate_content_async( contents=full_prompt, generation_config=generation_config, ) - + if not response or not response.text: - logger.warning( - f"[{self.model_name}] Empty Gemini generate_response. Returning empty." - ) + logger.warning(f"[{self.model_name}] Empty Gemini generate_response. Returning empty.") return "" return response.text.strip() except Exception as e: @@ -957,10 +908,7 @@ class DeepSeekClient(BaseModelClient): def __init__(self, model_name: str, prompts_dir: Optional[str] = None): super().__init__(model_name, prompts_dir=prompts_dir) self.api_key = os.environ.get("DEEPSEEK_API_KEY") - self.client = AsyncDeepSeekOpenAI( - api_key=self.api_key, - base_url="https://api.deepseek.com/" - ) + self.client = AsyncDeepSeekOpenAI(api_key=self.api_key, base_url="https://api.deepseek.com/") async def generate_response(self, prompt: str, temperature: float = 0.0, inject_random_seed: bool = True) -> str: try: @@ -982,13 +930,11 @@ class DeepSeekClient(BaseModelClient): temperature=temperature, max_tokens=self.max_tokens, ) - + logger.debug(f"[{self.model_name}] Raw DeepSeek response:\n{response}") if not response or not response.choices: - logger.warning( - f"[{self.model_name}] No valid response in generate_response." - ) + logger.warning(f"[{self.model_name}] No valid response in generate_response.") return "" content = response.choices[0].message.content.strip() @@ -998,9 +944,7 @@ class DeepSeekClient(BaseModelClient): return content except Exception as e: - logger.error( - f"[{self.model_name}] Unexpected error in generate_response: {e}" - ) + logger.error(f"[{self.model_name}] Unexpected error in generate_response: {e}") return "" @@ -1028,7 +972,7 @@ class OpenAIResponsesClient(BaseModelClient): system_prompt_content = f"{random_seed}\n\n{self.system_prompt}" full_prompt = f"{system_prompt_content}\n\n{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:" - + # Prepare the request payload payload = { "model": self.model_name, @@ -1036,79 +980,60 @@ class OpenAIResponsesClient(BaseModelClient): "temperature": temperature, "max_tokens": self.max_tokens, } - - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.api_key}" - } - + + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} + # Make the API call using aiohttp async with aiohttp.ClientSession() as session: async with session.post(self.base_url, json=payload, headers=headers) as response: if response.status != 200: error_text = await response.text() - logger.error( - f"[{self.model_name}] API error (status {response.status}): {error_text}" - ) + logger.error(f"[{self.model_name}] API error (status {response.status}): {error_text}") return "" - + response_data = await response.json() - + # Extract the text from the nested response structure # The text is in output[1].content[0].text based on the response try: outputs = response_data.get("output", []) if len(outputs) < 2: - logger.warning( - f"[{self.model_name}] Unexpected output structure. Full response: {response_data}" - ) + logger.warning(f"[{self.model_name}] Unexpected output structure. Full response: {response_data}") return "" - + # The message is typically in the second output item message_output = outputs[1] if message_output.get("type") != "message": - logger.warning( - f"[{self.model_name}] Expected message type in output[1]. Got: {message_output.get('type')}" - ) + logger.warning(f"[{self.model_name}] Expected message type in output[1]. Got: {message_output.get('type')}") return "" - + content_list = message_output.get("content", []) if not content_list: - logger.warning( - f"[{self.model_name}] Empty content list in message output" - ) + logger.warning(f"[{self.model_name}] Empty content list in message output") return "" - + # Look for the content item with type 'output_text' text_content = "" for content_item in content_list: if content_item.get("type") == "output_text": text_content = content_item.get("text", "") break - + if not text_content: - logger.warning( - f"[{self.model_name}] No output_text found in content. Full content: {content_list}" - ) + logger.warning(f"[{self.model_name}] No output_text found in content. Full content: {content_list}") return "" - + return text_content.strip() - + except (KeyError, IndexError, TypeError) as e: - logger.error( - f"[{self.model_name}] Error parsing response structure: {e}. Full response: {response_data}" - ) + logger.error(f"[{self.model_name}] Error parsing response structure: {e}. Full response: {response_data}") return "" - + except aiohttp.ClientError as e: - logger.error( - f"[{self.model_name}] HTTP client error in generate_response: {e}" - ) + logger.error(f"[{self.model_name}] HTTP client error in generate_response: {e}") return "" except Exception as e: - logger.error( - f"[{self.model_name}] Unexpected error in generate_response: {e}" - ) + logger.error(f"[{self.model_name}] Unexpected error in generate_response: {e}") return "" @@ -1123,17 +1048,14 @@ class OpenRouterClient(BaseModelClient): model_name = f"openrouter/{model_name}" if model_name.startswith("openrouter-"): model_name = model_name.replace("openrouter-", "") - + super().__init__(model_name, prompts_dir=prompts_dir) self.api_key = os.environ.get("OPENROUTER_API_KEY") if not self.api_key: raise ValueError("OPENROUTER_API_KEY environment variable is required") - - self.client = AsyncOpenAI( - base_url="https://openrouter.ai/api/v1", - api_key=self.api_key - ) - + + self.client = AsyncOpenAI(base_url="https://openrouter.ai/api/v1", api_key=self.api_key) + logger.debug(f"[{self.model_name}] Initialized OpenRouter client") async def generate_response(self, prompt: str, temperature: float = 0.0, inject_random_seed: bool = True) -> str: @@ -1150,26 +1072,23 @@ class OpenRouterClient(BaseModelClient): # Prepare standard OpenAI-compatible request response = await self.client.chat.completions.create( model=self.model_name, - messages=[ - {"role": "system", "content": system_prompt_content}, - {"role": "user", "content": prompt_with_cta} - ], + messages=[{"role": "system", "content": system_prompt_content}, {"role": "user", "content": prompt_with_cta}], max_tokens=self.max_tokens, temperature=temperature, ) - + if not response.choices: logger.warning(f"[{self.model_name}] OpenRouter returned no choices") return "" - + content = response.choices[0].message.content.strip() if not content: logger.warning(f"[{self.model_name}] OpenRouter returned empty content") return "" - + # Parse or return the raw content return content - + except Exception as e: error_msg = str(e) # Check if it's a specific OpenRouter error @@ -1200,7 +1119,7 @@ class TogetherAIClient(BaseModelClient): self.api_key = os.environ.get("TOGETHER_API_KEY") if not self.api_key: raise ValueError("TOGETHER_API_KEY environment variable is required for TogetherAIClient") - + # The model_name passed to super() is used for logging and identification. # The actual model name for the API call is self.model_name (from super class). self.client = AsyncTogether(api_key=self.api_key) @@ -1211,7 +1130,7 @@ class TogetherAIClient(BaseModelClient): Generates a response from the Together AI model. """ logger.debug(f"[{self.model_name}] Generating response with prompt (first 100 chars): {prompt[:100]}...") - + messages = [ {"role": "system", "content": self.system_prompt}, {"role": "user", "content": prompt}, @@ -1226,7 +1145,7 @@ class TogetherAIClient(BaseModelClient): # Consider adding max_tokens, temperature, etc. as needed # max_tokens=2048, # Example ) - + if response.choices and response.choices[0].message and response.choices[0].message.content is not None: content = response.choices[0].message.content logger.debug(f"[{self.model_name}] Received response (first 100 chars): {content[:100]}...") @@ -1236,16 +1155,16 @@ class TogetherAIClient(BaseModelClient): return "" except TogetherAPIError as e: logger.error(f"[{self.model_name}] Together AI API error: {e}", exc_info=True) - return f"Error: Together AI API error - {str(e)}" # Return a string with error info + return f"Error: Together AI API error - {str(e)}" # Return a string with error info except Exception as e: logger.error(f"[{self.model_name}] Unexpected error in TogetherAIClient: {e}", exc_info=True) - return f"Error: Unexpected error - {str(e)}" # Return a string with error info + return f"Error: Unexpected error - {str(e)}" # Return a string with error info ############################################################################## # RequestsOpenAIClient – sync requests, wrapped async (original + api_key) ############################################################################## -import requests, asyncio + class RequestsOpenAIClient(BaseModelClient): """ @@ -1266,18 +1185,14 @@ class RequestsOpenAIClient(BaseModelClient): if not self.api_key: raise ValueError("OPENAI_API_KEY missing and no inline key provided") - self.base_url = ( - base_url - or os.environ.get("OPENAI_BASE_URL") - or "https://api.openai.com/v1" - ).rstrip("/") + self.base_url = (base_url or os.environ.get("OPENAI_BASE_URL") or "https://api.openai.com/v1").rstrip("/") self.endpoint = f"{self.base_url}/chat/completions" # ---------------- internal blocking helper ---------------- # def _post_sync(self, payload: dict) -> dict: headers = { - "Content-Type": "application/json", + "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}", } r = requests.post(self.endpoint, headers=headers, json=payload, timeout=60) @@ -1291,17 +1206,13 @@ class RequestsOpenAIClient(BaseModelClient): temperature: float = 0.0, inject_random_seed: bool = True, ) -> str: - system_prompt_content = ( - f"{generate_random_seed()}\n\n{self.system_prompt}" - if inject_random_seed - else self.system_prompt - ) + system_prompt_content = f"{generate_random_seed()}\n\n{self.system_prompt}" if inject_random_seed else self.system_prompt payload = { "model": self.model_name, "messages": [ {"role": "system", "content": system_prompt_content}, - {"role": "user", "content": f"{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:"}, + {"role": "user", "content": f"{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:"}, ], "temperature": temperature, "max_tokens": self.max_tokens, @@ -1322,15 +1233,15 @@ class RequestsOpenAIClient(BaseModelClient): return "" - ############################################################################## # 3) Factory to Load Model Client ############################################################################## class ModelSpec(NamedTuple): - prefix: Optional[str] # 'openai', 'requests', … - model: str # 'gpt-4o' - base: Optional[str] # 'https://proxy.foo' - key: Optional[str] # 'sk-…' (may be None) + prefix: Optional[str] # 'openai', 'requests', … + model: str # 'gpt-4o' + base: Optional[str] # 'https://proxy.foo' + key: Optional[str] # 'sk-…' (may be None) + def _parse_model_spec(raw: str) -> ModelSpec: """ @@ -1339,11 +1250,11 @@ def _parse_model_spec(raw: str) -> ModelSpec: """ raw = raw.strip() - pre_hash, _, key_part = raw.partition("#") - pre_at, _, base_part = pre_hash.partition("@") + pre_hash, _, key_part = raw.partition("#") + pre_at, _, base_part = pre_hash.partition("@") maybe_pref, sep, model_part = pre_at.partition(":") - if sep: # explicit prefix was present + if sep: # explicit prefix was present prefix, model = maybe_pref.lower(), model_part else: prefix, model = None, maybe_pref @@ -1379,17 +1290,17 @@ def load_model_client(model_id: str, prompts_dir: Optional[str] = None) -> BaseM match spec.prefix: case "openai" | "oai": return OpenAIClient( - model_name = spec.model, - prompts_dir = prompts_dir, - base_url = spec.base, - api_key = inline_key, + model_name=spec.model, + prompts_dir=prompts_dir, + base_url=spec.base, + api_key=inline_key, ) case "requests" | "req": return RequestsOpenAIClient( - model_name = spec.model, - prompts_dir = prompts_dir, - base_url = spec.base, - api_key = inline_key, + model_name=spec.model, + prompts_dir=prompts_dir, + base_url=spec.base, + api_key=inline_key, ) case "responses" | "oai-resp" | "openai-responses": return OpenAIResponsesClient(spec.model, prompts_dir, api_key=inline_key) @@ -1409,7 +1320,7 @@ def load_model_client(model_id: str, prompts_dir: Optional[str] = None) -> BaseM # ------------------------------------------------------------------ # # 2. Heuristic fallback path (identical to the original behaviour) # # ------------------------------------------------------------------ # - lower_id = spec.model.lower() + lower_id = spec.model.lower() if lower_id == "o3-pro": return OpenAIResponsesClient(spec.model, prompts_dir, api_key=inline_key) @@ -1432,17 +1343,15 @@ def load_model_client(model_id: str, prompts_dir: Optional[str] = None) -> BaseM # Default: OpenAI-compatible async client return OpenAIClient( - model_name = spec.model, - prompts_dir = prompts_dir, - base_url = spec.base, - api_key = inline_key, + model_name=spec.model, + prompts_dir=prompts_dir, + base_url=spec.base, + api_key=inline_key, ) return OpenAIClient(model_name, prompts_dir, base_url) - - ############################################################################## # 1) Add a method to filter visible messages (near top-level or in BaseModelClient) ############################################################################## @@ -1453,11 +1362,6 @@ def get_visible_messages_for_power(conversation_messages, power_name): visible = [] for msg in conversation_messages: # GLOBAL might be 'ALL' or 'GLOBAL' depending on your usage - if ( - msg["recipient"] == "ALL" - or msg["recipient"] == "GLOBAL" - or msg["sender"] == power_name - or msg["recipient"] == power_name - ): + if msg["recipient"] == "ALL" or msg["recipient"] == "GLOBAL" or msg["sender"] == power_name or msg["recipient"] == power_name: visible.append(msg) return visible # already in chronological order if appended that way diff --git a/ai_diplomacy/diary_logic.py b/ai_diplomacy/diary_logic.py index aa0be0f..1cfa164 100644 --- a/ai_diplomacy/diary_logic.py +++ b/ai_diplomacy/diary_logic.py @@ -11,25 +11,28 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) + def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> str | None: """A local copy of the helper from agent.py to avoid circular imports.""" import os + try: if prompts_dir: filepath = os.path.join(prompts_dir, filename) else: current_dir = os.path.dirname(os.path.abspath(__file__)) - default_prompts_dir = os.path.join(current_dir, 'prompts') + default_prompts_dir = os.path.join(current_dir, "prompts") filepath = os.path.join(default_prompts_dir, filename) - with open(filepath, 'r', encoding='utf-8') as f: + with open(filepath, "r", encoding="utf-8") as f: return f.read() except Exception as e: logger.error(f"Error loading prompt file {filepath}: {e}") return None + async def run_diary_consolidation( - agent: 'DiplomacyAgent', + agent: "DiplomacyAgent", game: "Game", log_file_path: str, entries_to_keep_unsummarized: int = 6, @@ -39,70 +42,42 @@ async def run_diary_consolidation( Consolidate older diary entries while keeping recent ones. This is the logic moved from the DiplomacyAgent class. """ - logger.info( - f"[{agent.power_name}] CONSOLIDATION START — " - f"{len(agent.full_private_diary)} total full entries" - ) + logger.info(f"[{agent.power_name}] CONSOLIDATION START — {len(agent.full_private_diary)} total full entries") - full_entries = [ - e for e in agent.full_private_diary - if not e.startswith("[CONSOLIDATED HISTORY]") - ] + full_entries = [e for e in agent.full_private_diary if not e.startswith("[CONSOLIDATED HISTORY]")] if len(full_entries) <= entries_to_keep_unsummarized: agent.private_diary = list(agent.full_private_diary) - logger.info( - f"[{agent.power_name}] ≤ {entries_to_keep_unsummarized} full entries — " - "skipping consolidation" - ) + logger.info(f"[{agent.power_name}] ≤ {entries_to_keep_unsummarized} full entries — skipping consolidation") return boundary_entry = full_entries[-entries_to_keep_unsummarized] match = re.search(r"\[[SFWRAB]\s*(\d{4})", boundary_entry) if not match: - logger.error( - f"[{agent.power_name}] Could not parse year from boundary entry; " - "aborting consolidation" - ) + logger.error(f"[{agent.power_name}] Could not parse year from boundary entry; aborting consolidation") agent.private_diary = list(agent.full_private_diary) return cutoff_year = int(match.group(1)) - logger.info( - f"[{agent.power_name}] Cut-off year for consolidation: {cutoff_year}" - ) + logger.info(f"[{agent.power_name}] Cut-off year for consolidation: {cutoff_year}") def _entry_year(entry: str) -> int | None: m = re.search(r"\[[SFWRAB]\s*(\d{4})", entry) return int(m.group(1)) if m else None - entries_to_summarize = [ - e for e in full_entries - if (_entry_year(e) is not None and _entry_year(e) < cutoff_year) - ] - entries_to_keep = [ - e for e in full_entries - if (_entry_year(e) is None or _entry_year(e) >= cutoff_year) - ] + entries_to_summarize = [e for e in full_entries if (_entry_year(e) is not None and _entry_year(e) < cutoff_year)] + entries_to_keep = [e for e in full_entries if (_entry_year(e) is None or _entry_year(e) >= cutoff_year)] - logger.info( - f"[{agent.power_name}] Summarising {len(entries_to_summarize)} entries; " - f"keeping {len(entries_to_keep)} recent entries verbatim" - ) + logger.info(f"[{agent.power_name}] Summarising {len(entries_to_summarize)} entries; keeping {len(entries_to_keep)} recent entries verbatim") if not entries_to_summarize: agent.private_diary = list(agent.full_private_diary) - logger.warning( - f"[{agent.power_name}] No eligible entries to summarise; " - "context diary left unchanged" - ) + logger.warning(f"[{agent.power_name}] No eligible entries to summarise; context diary left unchanged") return prompt_template = _load_prompt_file("diary_consolidation_prompt.txt", prompts_dir=prompts_dir) if not prompt_template: - logger.error( - f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting" - ) + logger.error(f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting") return prompt = prompt_template.format( @@ -132,27 +107,18 @@ async def run_diary_consolidation( new_summary_entry = f"[CONSOLIDATED HISTORY] {consolidated_text}" agent.private_diary = [new_summary_entry] + entries_to_keep success_flag = "TRUE" - logger.info( - f"[{agent.power_name}] Consolidation complete — " - f"{len(agent.private_diary)} context entries now" - ) + logger.info(f"[{agent.power_name}] Consolidation complete — {len(agent.private_diary)} context entries now") except Exception as exc: - logger.error( - f"[{agent.power_name}] Diary consolidation failed: {exc}", exc_info=True - ) + logger.error(f"[{agent.power_name}] Diary consolidation failed: {exc}", exc_info=True) finally: log_llm_response( log_file_path=log_file_path, - model_name=( - consolidation_client.model_name - if consolidation_client is not None - else agent.client.model_name - ), + model_name=(consolidation_client.model_name if consolidation_client is not None else agent.client.model_name), power_name=agent.power_name, phase=game.current_short_phase, response_type="diary_consolidation", raw_input_prompt=prompt, raw_response=raw_response, success=success_flag, - ) \ No newline at end of file + ) diff --git a/ai_diplomacy/formatter.py b/ai_diplomacy/formatter.py index d61d5c4..20edfec 100644 --- a/ai_diplomacy/formatter.py +++ b/ai_diplomacy/formatter.py @@ -3,14 +3,13 @@ Formatter module for converting natural language LLM responses to structured JSO Uses Gemini 2.5 Flash via OpenRouter to extract and format information from reasoning-focused responses. """ -import json import logging -import os -from typing import Dict, Any, Optional +from typing import Optional from pathlib import Path # Import logging function and model configuration from .utils import log_llm_response, get_special_models + # Import client loading function from .clients import load_model_client @@ -26,77 +25,73 @@ FORMAT_ORDER_DIARY = "order_diary" async def format_with_gemini_flash( - raw_response: str, - format_type: str, - power_name: Optional[str] = None, - phase: Optional[str] = None, - log_file_path: Optional[str] = None + raw_response: str, format_type: str, power_name: Optional[str] = None, phase: Optional[str] = None, log_file_path: Optional[str] = None ) -> str: """ Format a natural language response into required JSON structure using Gemini Flash. - + Args: raw_response: Natural language response from primary LLM format_type: Type of formatting required (e.g., FORMAT_ORDERS, FORMAT_STATE_UPDATE) power_name: Optional power name for logging phase: Optional phase for logging log_file_path: Optional path for CSV logging - + Returns: JSON string in the expected format """ # Map format types to prompt files format_prompts = { FORMAT_STATE_UPDATE: "formatting/format_state_update.txt", - FORMAT_CONVERSATION: "formatting/format_conversation.txt", + FORMAT_CONVERSATION: "formatting/format_conversation.txt", FORMAT_NEGOTIATION_DIARY: "formatting/format_negotiation_diary.txt", FORMAT_ORDERS: "formatting/format_orders.txt", FORMAT_INITIAL_STATE: "formatting/format_initial_state.txt", - FORMAT_ORDER_DIARY: "formatting/format_order_diary.txt" + FORMAT_ORDER_DIARY: "formatting/format_order_diary.txt", } - + if format_type not in format_prompts: raise ValueError(f"Unknown format type: {format_type}") - + # Load the formatting prompt prompt_file = Path(__file__).parent / "prompts" / format_prompts[format_type] if not prompt_file.exists(): raise FileNotFoundError(f"Formatting prompt not found: {prompt_file}") - - with open(prompt_file, 'r') as f: + + with open(prompt_file, "r") as f: format_prompt = f.read() - + # Replace placeholder with actual response format_prompt = format_prompt.replace("[RAW_RESPONSE]", raw_response) - + # Get model name from configuration special_models = get_special_models() model_name = special_models["formatter"] - + # Load the formatter client using the same logic as other models formatter_client = load_model_client(model_name) - + try: logger.info(f"[FORMATTER] Calling {model_name} for {format_type} formatting") - + # Create the full prompt with system message system_content = "You are a precise formatting assistant. Extract and format information exactly as requested." formatter_client.set_system_prompt(system_content) - + # Use the client's generate_response method formatted_response = await formatter_client.generate_response( prompt=format_prompt, temperature=0, # Deterministic formatting - inject_random_seed=False # No need for random seed in formatting + inject_random_seed=False, # No need for random seed in formatting ) - + if not formatted_response: logger.warning(f"[FORMATTER] {model_name} returned empty response") return "" - + # Log successful formatting logger.info(f"[FORMATTER] Successfully formatted {format_type} response") - + # Strip any markdown formatting that Gemini might add if formatted_response.startswith("```json"): formatted_response = formatted_response[7:] @@ -105,7 +100,7 @@ async def format_with_gemini_flash( if formatted_response.endswith("```"): formatted_response = formatted_response[:-3] formatted_response = formatted_response.strip() - + # Log if requested if log_file_path: log_llm_response( @@ -116,14 +111,14 @@ async def format_with_gemini_flash( response_type=f"format_{format_type}", raw_input_prompt=format_prompt, raw_response=formatted_response, - success="Success" + success="Success", ) - + return formatted_response - + except Exception as e: logger.error(f"[FORMATTER] Error calling Gemini Flash: {e}") - + # Log error if requested if log_file_path: log_llm_response( @@ -134,11 +129,11 @@ async def format_with_gemini_flash( response_type=f"format_{format_type}", raw_input_prompt=format_prompt, raw_response=f"ERROR: {str(e)}", - success=f"Failure: {type(e).__name__}" + success=f"Failure: {type(e).__name__}", ) - + # Return empty structure based on format type if format_type == FORMAT_CONVERSATION: return "[]" else: - return "{}" \ No newline at end of file + return "{}" diff --git a/ai_diplomacy/game_history.py b/ai_diplomacy/game_history.py index 27abd23..c0bf78e 100644 --- a/ai_diplomacy/game_history.py +++ b/ai_diplomacy/game_history.py @@ -3,7 +3,6 @@ import logging from collections import defaultdict from dataclasses import dataclass, field from typing import Dict, List, Optional -from pprint import pformat logger = logging.getLogger("utils") logger.setLevel(logging.INFO) @@ -24,12 +23,8 @@ class Phase: plans: Dict[str, str] = field(default_factory=dict) messages: List[Message] = field(default_factory=list) submitted_orders_by_power: Dict[str, List[str]] = field(default_factory=dict) - orders_by_power: Dict[str, List[str]] = field( - default_factory=lambda: defaultdict(list) - ) - results_by_power: Dict[str, List[List[str]]] = field( - default_factory=lambda: defaultdict(list) - ) + orders_by_power: Dict[str, List[str]] = field(default_factory=lambda: defaultdict(list)) + results_by_power: Dict[str, List[List[str]]] = field(default_factory=lambda: defaultdict(list)) # NEW: Store phase-end summaries provided by each power phase_summaries: Dict[str, str] = field(default_factory=dict) # NEW: Store experience/journal updates from each power for this phase @@ -37,11 +32,9 @@ class Phase: def add_plan(self, power_name: str, plan: str): self.plans[power_name] = plan - + def add_message(self, sender: str, recipient: str, content: str): - self.messages.append( - Message(sender=sender, recipient=recipient, content=content) - ) + self.messages.append(Message(sender=sender, recipient=recipient, content=content)) def add_orders(self, power: str, orders: List[str], results: List[List[str]]): self.orders_by_power[power].extend(orders) @@ -110,14 +103,10 @@ class GameHistory: phase.plans[power_name] = plan logger.debug(f"Added plan for {power_name} in {phase_name}") - def add_message( - self, phase_name: str, sender: str, recipient: str, message_content: str - ): + def add_message(self, phase_name: str, sender: str, recipient: str, message_content: str): phase = self._get_phase(phase_name) if phase: - message = Message( - sender=sender, recipient=recipient, content=message_content - ) + message = Message(sender=sender, recipient=recipient, content=message_content) phase.messages.append(message) logger.debug(f"Added message from {sender} to {recipient} in {phase_name}") @@ -147,16 +136,12 @@ class GameHistory: phase.experience_updates[power_name] = update logger.debug(f"Added experience update for {power_name} in {phase_name}") - def get_strategic_directives(self): + def get_strategic_directives(self): # returns for last phase only if exists - if not self.phases: + if not self.phases: return {} return self.phases[-1].plans - - - - def get_order_history_for_prompt( self, game: "Game", @@ -194,10 +179,7 @@ class GameHistory: return str(tag).strip().lower() engine_phases = {ph.name: ph for ph in getattr(game, "get_phase_history", lambda: [])()} - eng2code = { - "AUSTRIA": "AUT", "ENGLAND": "ENG", "FRANCE": "FRA", - "GERMANY": "GER", "ITALY": "ITA", "RUSSIA": "RUS", "TURKEY": "TUR" - } + eng2code = {"AUSTRIA": "AUT", "ENGLAND": "ENG", "FRANCE": "FRA", "GERMANY": "GER", "ITALY": "ITA", "RUSSIA": "RUS", "TURKEY": "TUR"} norm = game.map.norm out_lines = ["**ORDER HISTORY (Recent Rounds)**"] @@ -209,14 +191,14 @@ class GameHistory: for pwr in sorted(set(ph.orders_by_power) | set(ph.submitted_orders_by_power)): submitted = ph.submitted_orders_by_power.get(pwr, []) - accepted = ph.orders_by_power.get(pwr, []) + accepted = ph.orders_by_power.get(pwr, []) if isinstance(submitted, str): submitted = [submitted] if isinstance(accepted, str): accepted = [accepted] - def _norm_keep(o): # keep WAIVE readable + def _norm_keep(o): # keep WAIVE readable return o if o.upper() == "WAIVE" else norm(o) sub_norm = {_norm_keep(o): o for o in submitted} @@ -264,13 +246,6 @@ class GameHistory: return "\n(No orders were issued in recent history)\n" return "\n".join(out_lines) - - - - - - - def get_messages_this_round(self, power_name: str, current_phase_name: str) -> str: current_phase: Optional[Phase] = None for phase_obj in self.phases: @@ -281,7 +256,7 @@ class GameHistory: if not current_phase: return f"\n(No messages found for current phase: {current_phase_name})\n" - messages_str = "" + messages_str = "" global_msgs_content = current_phase.get_global_messages() if global_msgs_content: @@ -299,7 +274,7 @@ class GameHistory: messages_str += "\n" else: messages_str += "\n**PRIVATE MESSAGES TO/FROM YOU THIS ROUND:**\n (No private messages this round)\n" - + if not global_msgs_content and not private_msgs_dict: return f"\n(No messages recorded for current phase: {current_phase_name})\n" @@ -313,10 +288,10 @@ class GameHistory: """ if not self.phases: return [] - + # Get the most recent 2 phases including current phase recent_phases = self.phases[-2:] if len(self.phases) >= 2 else self.phases[-1:] - + # Collect all messages sent TO this power messages_to_power = [] for phase in recent_phases: @@ -325,48 +300,44 @@ class GameHistory: if msg.recipient == power_name or (msg.recipient == "GLOBAL" and msg.sender != power_name): # Skip if sender is this power (don't need to respond to own messages) if msg.sender != power_name: - messages_to_power.append({ - 'sender': msg.sender, - 'content': msg.content, - 'phase': phase.name - }) - + messages_to_power.append({"sender": msg.sender, "content": msg.content, "phase": phase.name}) + # Add debug logging logger.info(f"Found {len(messages_to_power)} messages to {power_name} across {len(recent_phases)} phases") if not messages_to_power: logger.info(f"No messages found for {power_name} to respond to") - + # Take the most recent 'limit' messages return messages_to_power[-limit:] if messages_to_power else [] - + def get_ignored_messages_by_power(self, sender_name: str, num_phases: int = 3) -> Dict[str, List[Dict[str, str]]]: """ Identifies which powers are not responding to messages from sender_name. Returns a dict mapping power names to their ignored messages. - + A message is considered ignored if: 1. It was sent from sender_name to another power (private) 2. No response from that power was received in the same or next phase """ ignored_by_power = {} - + # Get recent phases recent_phases = self.phases[-num_phases:] if self.phases else [] if not recent_phases: return ignored_by_power - + for i, phase in enumerate(recent_phases): # Get messages sent by sender to specific powers (not global) sender_messages = [] for msg in phase.messages: # Handle both Message objects and dict objects if isinstance(msg, Message): - if msg.sender == sender_name and msg.recipient not in ['GLOBAL', 'ALL']: + if msg.sender == sender_name and msg.recipient not in ["GLOBAL", "ALL"]: sender_messages.append(msg) else: # Assume dict - if msg['sender'] == sender_name and msg['recipient'] not in ['GLOBAL', 'ALL']: + if msg["sender"] == sender_name and msg["recipient"] not in ["GLOBAL", "ALL"]: sender_messages.append(msg) - + # Check for responses in this and next phases for msg in sender_messages: # Handle both Message objects and dict objects @@ -374,37 +345,36 @@ class GameHistory: recipient = msg.recipient msg_content = msg.content else: - recipient = msg['recipient'] - msg_content = msg['content'] - + recipient = msg["recipient"] + msg_content = msg["content"] + # Look for responses in current phase and next phases found_response = False - + # Check remaining phases starting from current - for check_phase in recent_phases[i:min(i+2, len(recent_phases))]: + for check_phase in recent_phases[i : min(i + 2, len(recent_phases))]: # Look for messages FROM the recipient TO the sender (direct response) # or FROM the recipient to GLOBAL/ALL that might acknowledge sender response_msgs = [] for m in check_phase.messages: if isinstance(m, Message): - if m.sender == recipient and (m.recipient == sender_name or - (m.recipient in ['GLOBAL', 'ALL'] and sender_name in m.content)): + if m.sender == recipient and ( + m.recipient == sender_name or (m.recipient in ["GLOBAL", "ALL"] and sender_name in m.content) + ): response_msgs.append(m) else: # Assume dict - if m['sender'] == recipient and (m['recipient'] == sender_name or - (m['recipient'] in ['GLOBAL', 'ALL'] and sender_name in m.get('content', ''))): + if m["sender"] == recipient and ( + m["recipient"] == sender_name or (m["recipient"] in ["GLOBAL", "ALL"] and sender_name in m.get("content", "")) + ): response_msgs.append(m) - + if response_msgs: found_response = True break - + if not found_response: if recipient not in ignored_by_power: ignored_by_power[recipient] = [] - ignored_by_power[recipient].append({ - 'phase': phase.name, - 'content': msg_content - }) - - return ignored_by_power \ No newline at end of file + ignored_by_power[recipient].append({"phase": phase.name, "content": msg_content}) + + return ignored_by_power diff --git a/ai_diplomacy/game_logic.py b/ai_diplomacy/game_logic.py index 855c426..8a44871 100644 --- a/ai_diplomacy/game_logic.py +++ b/ai_diplomacy/game_logic.py @@ -3,7 +3,7 @@ import logging import os import json import asyncio -from typing import Dict, List, Tuple, Optional, Any +from typing import Dict, Tuple, Optional, Any from argparse import Namespace from diplomacy import Game @@ -19,6 +19,7 @@ logger = logging.getLogger(__name__) # --- Serialization / Deserialization --- + def serialize_agent(agent: DiplomacyAgent) -> dict: """Converts an agent object to a JSON-serializable dictionary.""" return { @@ -31,12 +32,8 @@ def serialize_agent(agent: DiplomacyAgent) -> dict: "private_diary": agent.private_diary, } -def deserialize_agent( - agent_data: dict, - prompts_dir: Optional[str] = None, - *, - override_model_id: Optional[str] = None -) -> DiplomacyAgent: + +def deserialize_agent(agent_data: dict, prompts_dir: Optional[str] = None, *, override_model_id: Optional[str] = None) -> DiplomacyAgent: """ Recreates an agent object from a dictionary. @@ -68,34 +65,31 @@ def deserialize_agent( # --- State Management --- # game_logic.py -_PHASE_ORDER = ["M", "R", "A"] # Movement → Retreats → Adjustments +_PHASE_ORDER = ["M", "R", "A"] # Movement → Retreats → Adjustments + def _next_phase_name(short: str) -> str: """ Return the Diplomacy phase string that chronologically follows *short*. (E.g. S1901M → S1901R, S1901R → W1901A, W1901A → S1902M) """ - season = short[0] # 'S' | 'W' - year = int(short[1:5]) - typ = short[-1] # 'M' | 'R' | 'A' + season = short[0] # 'S' | 'W' + year = int(short[1:5]) + typ = short[-1] # 'M' | 'R' | 'A' idx = _PHASE_ORDER.index(typ) - if idx < 2: # still in the same season - return f"{season}{year}{_PHASE_ORDER[idx+1]}" + if idx < 2: # still in the same season + return f"{season}{year}{_PHASE_ORDER[idx + 1]}" # typ was 'A' → roll season - if season == "S": # summer → winter, same year + if season == "S": # summer → winter, same year return f"W{year}M" - else: # winter→ spring, next year - return f"S{year+1}M" + else: # winter→ spring, next year + return f"S{year + 1}M" + def save_game_state( - game: Game, - agents: Dict[str, DiplomacyAgent], - game_history: GameHistory, - output_path: str, - run_config: Namespace, - completed_phase_name: str + game: Game, agents: Dict[str, DiplomacyAgent], game_history: GameHistory, output_path: str, run_config: Namespace, completed_phase_name: str ): """ Serialise the entire game to JSON, preserving per-phase custom metadata @@ -130,9 +124,7 @@ def save_game_state( if extras: previous_phase_extras[phase["name"]] = extras except Exception as exc: - logger.warning( - "Could not load previous save to retain metadata: %s", exc, exc_info=True - ) + logger.warning("Could not load previous save to retain metadata: %s", exc, exc_info=True) # -------------------------------------------------------------- # # 2. Build the fresh base structure from the diplomacy library. # @@ -143,16 +135,12 @@ def save_game_state( # 3. Walk every phase and merge the metadata back in. # # -------------------------------------------------------------- # # Capture the *current* snapshot of every live agent exactly once. - current_state_agents = { - p_name: serialize_agent(p_agent) - for p_name, p_agent in agents.items() - if not game.powers[p_name].is_eliminated() - } + current_state_agents = {p_name: serialize_agent(p_agent) for p_name, p_agent in agents.items() if not game.powers[p_name].is_eliminated()} for phase_block in saved_game.get("phases", []): if int(phase_block["name"][1:5]) > run_config.max_year: break - + phase_name = phase_block["name"] # 3a. Re-attach anything we cached from a previous save. @@ -167,31 +155,24 @@ def save_game_state( # Plans for this phase – may be empty in non-movement phases. phase_obj = game_history._get_phase(phase_name) - phase_block["state_history_plans"] = ( - phase_obj.plans if phase_obj else {} - ) - + phase_block["state_history_plans"] = phase_obj.plans if phase_obj else {} # -------------------------------------------------------------- # # 4. Attach top-level metadata and write atomically. # # -------------------------------------------------------------- # saved_game["phase_summaries"] = getattr(game, "phase_summaries", {}) - saved_game["final_agent_states"] = { - p_name: {"relationships": a.relationships, "goals": a.goals} - for p_name, a in agents.items() - } + saved_game["final_agent_states"] = {p_name: {"relationships": a.relationships, "goals": a.goals} for p_name, a in agents.items()} # Filter out phases > max_year - #saved_game["phases"] = [ + # saved_game["phases"] = [ # ph for ph in saved_game["phases"] # if int(ph["name"][1:5]) <= run_config.max_year # <= 1902, for example - #] + # ] atomic_write_json(saved_game, output_path) logger.info("Game state saved successfully.") - def load_game_state( run_dir: str, game_file_name: str, @@ -204,7 +185,7 @@ def load_game_state( raise FileNotFoundError(f"Cannot resume. Save file not found at: {game_file_path}") logger.info(f"Loading game state from: {game_file_path}") - with open(game_file_path, 'r') as f: + with open(game_file_path, "r") as f: saved_game_data = json.load(f) # If resuming, find the specified phase and truncate the data after it @@ -213,49 +194,49 @@ def load_game_state( try: # Find the index of the phase *before* the one we want to resume from. # We will start the simulation *at* the resume_from_phase. - resume_idx = next(i for i, phase in enumerate(saved_game_data['phases']) if phase['name'] == resume_from_phase) + resume_idx = next(i for i, phase in enumerate(saved_game_data["phases"]) if phase["name"] == resume_from_phase) # Truncate the list to exclude everything after the resume phase # Note: the state saved for a given phase represents the state at the beginning of that phase. - saved_game_data['phases'] = saved_game_data['phases'][:resume_idx+1] + saved_game_data["phases"] = saved_game_data["phases"][: resume_idx + 1] # Wipe any data that must be regenerated. for key in ("orders", "results", "messages"): - saved_game_data['phases'][-1].pop(key, None) + saved_game_data["phases"][-1].pop(key, None) logger.info(f"Game history truncated to {len(saved_game_data['phases'])} phases. The next phase to run will be {resume_from_phase}.") except StopIteration: # If the phase is not found, maybe it's the first phase (S1901M) if resume_from_phase == "S1901M": - saved_game_data['phases'] = [] - logger.info("Resuming from S1901M. Starting with a clean history.") + saved_game_data["phases"] = [] + logger.info("Resuming from S1901M. Starting with a clean history.") else: raise ValueError(f"Resume phase '{resume_from_phase}' not found in the save file.") # Reconstruct the Game object - last_phase = saved_game_data['phases'][-1] + last_phase = saved_game_data["phases"][-1] # Wipe the data that must be regenerated **but preserve the keys** - last_phase['orders'] = {} # was dict - last_phase['results'] = {} # was dict - last_phase['messages'] = [] + last_phase["orders"] = {} # was dict + last_phase["results"] = {} # was dict + last_phase["messages"] = [] game = from_saved_game_format(saved_game_data) - game.phase_summaries = saved_game_data.get('phase_summaries', {}) + game.phase_summaries = saved_game_data.get("phase_summaries", {}) # Reconstruct agents and game history from the *last* valid phase in the data - if not saved_game_data['phases']: + if not saved_game_data["phases"]: # This happens if we are resuming from the very beginning (S1901M) logger.info("No previous phases found. Initializing fresh agents and history.") - agents = {} # Will be created by the main loop + agents = {} # Will be created by the main loop game_history = GameHistory() else: # We save the game state up to & including the current (uncompleted) phase. # So we need to grab the agent state from the previous (completed) phase. - if len(saved_game_data['phases']) <= 1: + if len(saved_game_data["phases"]) <= 1: last_phase_data = {} else: - last_phase_data = saved_game_data['phases'][-2] - + last_phase_data = saved_game_data["phases"][-2] + # -------------------- Rebuild agents -------------------- # agents = {} @@ -269,18 +250,11 @@ def load_game_state( elif len(provided) == 1: power_model_map = dict(zip(powers_order, provided * len(powers_order))) else: - raise ValueError( - f"Invalid --models argument: expected 1 or {len(powers_order)} " - f"items, got {len(provided)}." - ) + raise ValueError(f"Invalid --models argument: expected 1 or {len(powers_order)} items, got {len(provided)}.") if "state_agents" in last_phase_data: logger.info("Rebuilding agents from saved state...") - prompts_dir_from_config = ( - run_config.prompts_dir - if run_config and hasattr(run_config, "prompts_dir") - else None - ) + prompts_dir_from_config = run_config.prompts_dir if run_config and hasattr(run_config, "prompts_dir") else None for power_name, agent_data in last_phase_data["state_agents"].items(): override_id = power_model_map.get(power_name) agents[power_name] = deserialize_agent( @@ -290,40 +264,30 @@ def load_game_state( ) logger.info(f"Rebuilt {len(agents)} agents.") else: - raise ValueError( - "Cannot resume: 'state_agents' key not found in the last phase " - "of the save file." - ) - + raise ValueError("Cannot resume: 'state_agents' key not found in the last phase of the save file.") # Rebuild GameHistory game_history = GameHistory() logger.info("Rebuilding game history...") - for phase_data in saved_game_data['phases'][:-1]: - phase_name = phase_data['name'] + for phase_data in saved_game_data["phases"][:-1]: + phase_name = phase_data["name"] game_history.add_phase(phase_name) # Add messages - for msg in phase_data.get('messages', []): - game_history.add_message(phase_name, msg['sender'], msg['recipient'], msg['message']) + for msg in phase_data.get("messages", []): + game_history.add_message(phase_name, msg["sender"], msg["recipient"], msg["message"]) # Add plans - if 'state_history_plans' in phase_data: - for p_name, plan in phase_data['state_history_plans'].items(): + if "state_history_plans" in phase_data: + for p_name, plan in phase_data["state_history_plans"].items(): game_history.add_plan(phase_name, p_name, plan) logger.info("Game history rebuilt.") - return game, agents, game_history, run_config -async def initialize_new_game( - args: Namespace, - game: Game, - game_history: GameHistory, - llm_log_file_path: str -) -> Dict[str, DiplomacyAgent]: +async def initialize_new_game(args: Namespace, game: Game, game_history: GameHistory, llm_log_file_path: str) -> Dict[str, DiplomacyAgent]: """Initializes agents for a new game.""" powers_order = sorted(list(ALL_POWERS)) - + # Parse token limits default_max_tokens = args.max_tokens model_max_tokens = {p: default_max_tokens for p in powers_order} @@ -339,7 +303,7 @@ async def initialize_new_game( # Handle power model mapping if args.models: - provided_models = [name.strip() for name in args.models.split(",")] + provided_models = [name.strip() for name in args.models.split(",")] if len(provided_models) == len(powers_order): game.power_model_map = dict(zip(powers_order, provided_models)) elif len(provided_models) == 1: @@ -364,17 +328,17 @@ async def initialize_new_game( initialization_tasks.append(initialize_agent_state_ext(agent, game, game_history, llm_log_file_path, prompts_dir=args.prompts_dir)) except Exception as e: logger.error(f"Failed to create agent or client for {power_name} with model {model_id}: {e}", exc_info=True) - + logger.info(f"Running {len(initialization_tasks)} agent initializations concurrently...") initialization_results = await asyncio.gather(*initialization_tasks, return_exceptions=True) - + initialized_powers = list(agents.keys()) for i, result in enumerate(initialization_results): - if i < len(initialized_powers): - power_name = initialized_powers[i] - if isinstance(result, Exception): - logger.error(f"Failed to initialize agent state for {power_name}: {result}", exc_info=result) - else: - logger.info(f"Successfully initialized agent state for {power_name}.") - - return agents \ No newline at end of file + if i < len(initialized_powers): + power_name = initialized_powers[i] + if isinstance(result, Exception): + logger.error(f"Failed to initialize agent state for {power_name}: {result}", exc_info=result) + else: + logger.info(f"Successfully initialized agent state for {power_name}.") + + return agents diff --git a/ai_diplomacy/initialization.py b/ai_diplomacy/initialization.py index 5f460cb..32b6be1 100644 --- a/ai_diplomacy/initialization.py +++ b/ai_diplomacy/initialization.py @@ -1,12 +1,11 @@ # ai_diplomacy/initialization.py import logging import json -import os from typing import Optional from ..config import config # Forward declaration for type hinting, actual imports in function if complex -if False: # TYPE_CHECKING +if False: # TYPE_CHECKING from diplomacy import Game from diplomacy.models.game import GameHistory from .agent import DiplomacyAgent @@ -18,40 +17,40 @@ from .formatter import format_with_gemini_flash, FORMAT_INITIAL_STATE logger = logging.getLogger(__name__) + async def initialize_agent_state_ext( - agent: 'DiplomacyAgent', - game: 'Game', - game_history: 'GameHistory', + agent: "DiplomacyAgent", + game: "Game", + game_history: "GameHistory", log_file_path: str, prompts_dir: Optional[str] = None, ): """Uses the LLM to set initial goals and relationships for the agent.""" power_name = agent.power_name - logger.info(f"[{power_name}] Initializing agent state using LLM (external function)..." ) + logger.info(f"[{power_name}] Initializing agent state using LLM (external function)...") current_phase = game.get_current_phase() if game else "UnknownPhase" full_prompt = "" # Ensure full_prompt is defined in the outer scope for finally block - response = "" # Ensure response is defined for finally block - success_status = "Failure: Initialized" # Default status + response = "" # Ensure response is defined for finally block + success_status = "Failure: Initialized" # Default status try: # Load the prompt template allowed_labels_str = ", ".join(ALLOWED_RELATIONSHIPS) - initial_prompt_template = _load_prompt_file(get_prompt_path('initial_state_prompt.txt'), prompts_dir=prompts_dir) + initial_prompt_template = _load_prompt_file(get_prompt_path("initial_state_prompt.txt"), prompts_dir=prompts_dir) if not initial_prompt_template: logger.error(f"[{power_name}] Could not load {get_prompt_path('initial_state_prompt.txt')}. Cannot initialize.") return - + # Format the prompt with variables - initial_prompt = initial_prompt_template.format( - power_name=power_name, - allowed_labels_str=allowed_labels_str - ) + initial_prompt = initial_prompt_template.format(power_name=power_name, allowed_labels_str=allowed_labels_str) board_state = game.get_state() if game else {} possible_orders = game.get_all_possible_orders() if game else {} - logger.debug(f"[{power_name}] Preparing context for initial state. Board state type: {type(board_state)}, possible_orders type: {type(possible_orders)}, game_history type: {type(game_history)}") + logger.debug( + f"[{power_name}] Preparing context for initial state. Board state type: {type(board_state)}, possible_orders type: {type(possible_orders)}, game_history type: {type(game_history)}" + ) # Ensure agent.client and its methods can handle None for game/board_state/etc. if that's a possibility # For initialization, game should always be present. @@ -59,12 +58,12 @@ async def initialize_agent_state_ext( context = build_context_prompt( game=game, - board_state=board_state, + board_state=board_state, power_name=power_name, - possible_orders=possible_orders, - game_history=game_history, - agent_goals=None, - agent_relationships=None, + possible_orders=possible_orders, + game_history=game_history, + agent_goals=None, + agent_relationships=None, agent_private_diary=formatted_diary, prompts_dir=prompts_dir, ) @@ -76,9 +75,9 @@ async def initialize_agent_state_ext( log_file_path=log_file_path, power_name=power_name, phase=current_phase, - response_type='initialization', # Context for run_llm_and_log internal error logging + response_type="initialization", # Context for run_llm_and_log internal error logging ) - logger.debug(f"[{power_name}] LLM response for initial state: {response[:300]}...") # Log a snippet + logger.debug(f"[{power_name}] LLM response for initial state: {response[:300]}...") # Log a snippet parsed_successfully = False try: @@ -86,11 +85,7 @@ async def initialize_agent_state_ext( if config.USE_UNFORMATTED_PROMPTS: # Format the natural language response into JSON formatted_response = await format_with_gemini_flash( - response, - FORMAT_INITIAL_STATE, - power_name=power_name, - phase=current_phase, - log_file_path=log_file_path + response, FORMAT_INITIAL_STATE, power_name=power_name, phase=current_phase, log_file_path=log_file_path ) else: # Use the raw response directly (already formatted) @@ -101,20 +96,24 @@ async def initialize_agent_state_ext( except json.JSONDecodeError as e: logger.error(f"[{power_name}] All JSON extraction attempts failed: {e}. Response snippet: {response[:300]}...") success_status = "Failure: JSONDecodeError" - update_data = {} # Ensure update_data exists for fallback logic below - parsed_successfully = False # Explicitly set here too + update_data = {} # Ensure update_data exists for fallback logic below + parsed_successfully = False # Explicitly set here too # Fallback logic for goals/relationships will be handled later if update_data is empty # Defensive check for update_data type if parsing was initially considered successful - if parsed_successfully: + if parsed_successfully: if isinstance(update_data, str): - logger.error(f"[{power_name}] _extract_json_from_text returned a string, not a dict/list, despite not raising an exception. This indicates an unexpected parsing issue. String returned: {update_data[:300]}...") - update_data = {} # Treat as parsing failure + logger.error( + f"[{power_name}] _extract_json_from_text returned a string, not a dict/list, despite not raising an exception. This indicates an unexpected parsing issue. String returned: {update_data[:300]}..." + ) + update_data = {} # Treat as parsing failure parsed_successfully = False success_status = "Failure: ParsedAsStr" - elif not isinstance(update_data, dict): # Expecting a dict from JSON object - logger.error(f"[{power_name}] _extract_json_from_text returned a non-dict type ({type(update_data)}), expected dict. Data: {str(update_data)[:300]}") - update_data = {} # Treat as parsing failure + elif not isinstance(update_data, dict): # Expecting a dict from JSON object + logger.error( + f"[{power_name}] _extract_json_from_text returned a non-dict type ({type(update_data)}), expected dict. Data: {str(update_data)[:300]}" + ) + update_data = {} # Treat as parsing failure parsed_successfully = False success_status = "Failure: NotADict" @@ -122,8 +121,8 @@ async def initialize_agent_state_ext( initial_relationships_applied = False if parsed_successfully: - initial_goals = update_data.get('initial_goals') or update_data.get('goals') - initial_relationships = update_data.get('initial_relationships') or update_data.get('relationships') + initial_goals = update_data.get("initial_goals") or update_data.get("goals") + initial_relationships = update_data.get("initial_relationships") or update_data.get("relationships") if isinstance(initial_goals, list) and initial_goals: agent.goals = initial_goals @@ -152,30 +151,30 @@ async def initialize_agent_state_ext( else: logger.warning(f"[{power_name}] No valid relationships found in LLM response.") else: - logger.warning(f"[{power_name}] LLM did not provide valid 'initial_relationships' dict (got: {initial_relationships}).") - + logger.warning(f"[{power_name}] LLM did not provide valid 'initial_relationships' dict (got: {initial_relationships}).") + if initial_goals_applied or initial_relationships_applied: success_status = "Success: Applied LLM data" - elif parsed_successfully: # Parsed but nothing useful to apply + elif parsed_successfully: # Parsed but nothing useful to apply success_status = "Success: Parsed but no data applied" # If not parsed_successfully, success_status is already "Failure: JSONDecodeError" # Fallback if LLM data was not applied or parsing failed if not initial_goals_applied: - if not agent.goals: # Only set defaults if no goals were set during agent construction or by LLM + if not agent.goals: # Only set defaults if no goals were set during agent construction or by LLM agent.goals = ["Survive and expand", "Form beneficial alliances", "Secure key territories"] agent.add_journal_entry(f"[{current_phase}] Set default initial goals as LLM provided none or parse failed.") logger.info(f"[{power_name}] Default goals set.") - + if not initial_relationships_applied: - # Check if relationships are still default-like before overriding + # Check if relationships are still default-like before overriding is_default_relationships = True - if agent.relationships: # Check if it's not empty + if agent.relationships: # Check if it's not empty for p in ALL_POWERS: if p != power_name and agent.relationships.get(p) != "Neutral": is_default_relationships = False break - if is_default_relationships: + if is_default_relationships: agent.relationships = {p: "Neutral" for p in ALL_POWERS if p != power_name} agent.add_journal_entry(f"[{current_phase}] Set default neutral relationships as LLM provided none valid or parse failed.") logger.info(f"[{power_name}] Default neutral relationships set.") @@ -191,16 +190,16 @@ async def initialize_agent_state_ext( agent.relationships = {p: "Neutral" for p in ALL_POWERS if p != power_name} logger.info(f"[{power_name}] Set fallback neutral relationships after top-level error: {agent.relationships}") finally: - if log_file_path: # Ensure log_file_path is provided + if log_file_path: # Ensure log_file_path is provided log_llm_response( log_file_path=log_file_path, model_name=agent.client.model_name if agent and agent.client else "UnknownModel", power_name=power_name, phase=current_phase, - response_type="initial_state_setup", # Specific type for CSV logging + response_type="initial_state_setup", # Specific type for CSV logging raw_input_prompt=full_prompt, raw_response=response, - success=success_status + success=success_status, ) # Final log of state after initialization attempt diff --git a/ai_diplomacy/narrative.py b/ai_diplomacy/narrative.py index ef7f087..831d094 100644 --- a/ai_diplomacy/narrative.py +++ b/ai_diplomacy/narrative.py @@ -209,4 +209,3 @@ def _patched_generate_phase_summary(self: Game, phase_key, summary_callback=None Game._generate_phase_summary = _patched_generate_phase_summary # type: ignore[assignment] LOGGER.info("Game._generate_phase_summary patched with narrative generation.") - diff --git a/ai_diplomacy/negotiations.py b/ai_diplomacy/negotiations.py index 4c471aa..3259337 100644 --- a/ai_diplomacy/negotiations.py +++ b/ai_diplomacy/negotiations.py @@ -6,8 +6,7 @@ from typing import Dict, TYPE_CHECKING from diplomacy.engine.message import Message, GLOBAL from .agent import DiplomacyAgent -from .clients import load_model_client -from .utils import gather_possible_orders, load_prompt, normalize_recipient_name +from .utils import gather_possible_orders, normalize_recipient_name if TYPE_CHECKING: from .game_history import GameHistory @@ -21,9 +20,9 @@ load_dotenv() async def conduct_negotiations( - game: 'Game', + game: "Game", agents: Dict[str, DiplomacyAgent], - game_history: 'GameHistory', + game_history: "GameHistory", model_error_stats: Dict[str, Dict[str, int]], log_file_path: str, max_rounds: int = 3, @@ -35,13 +34,9 @@ async def conduct_negotiations( """ logger.info("Starting negotiation phase.") - active_powers = [ - p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated() - ] - eliminated_powers = [ - p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated() - ] - + active_powers = [p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()] + eliminated_powers = [p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()] + logger.info(f"Active powers for negotiations: {active_powers}") if eliminated_powers: logger.info(f"Eliminated powers (skipped): {eliminated_powers}") @@ -51,7 +46,7 @@ async def conduct_negotiations( # We do up to 'max_rounds' single-message turns for each power for round_index in range(max_rounds): logger.info(f"Negotiation Round {round_index + 1}/{max_rounds}") - + # Prepare tasks for asyncio.gather tasks = [] power_names_for_tasks = [] @@ -99,28 +94,28 @@ async def conduct_negotiations( # Process results for i, result in enumerate(results): power_name = power_names_for_tasks[i] - agent = agents[power_name] # Get agent again for journaling - model_name = agent.client.model_name # Get model name for stats + agent = agents[power_name] # Get agent again for journaling + model_name = agent.client.model_name # Get model name for stats if isinstance(result, Exception): logger.error(f"Error getting conversation reply for {power_name}: {result}", exc_info=result) # Use model_name for stats key if possible if model_name in model_error_stats: - model_error_stats[model_name]["conversation_errors"] += 1 - else: # Fallback to power_name if model name not tracked (shouldn't happen) - model_error_stats.setdefault(power_name, {}).setdefault("conversation_errors", 0) - model_error_stats[power_name]["conversation_errors"] += 1 - messages = [] # Treat as no messages on error - elif result is None: # Handle case where client might return None on internal error - logger.warning(f"Received None instead of messages for {power_name}.") - messages = [] - if model_name in model_error_stats: - model_error_stats[model_name]["conversation_errors"] += 1 - else: - model_error_stats.setdefault(power_name, {}).setdefault("conversation_errors", 0) - model_error_stats[power_name]["conversation_errors"] += 1 + model_error_stats[model_name]["conversation_errors"] += 1 + else: # Fallback to power_name if model name not tracked (shouldn't happen) + model_error_stats.setdefault(power_name, {}).setdefault("conversation_errors", 0) + model_error_stats[power_name]["conversation_errors"] += 1 + messages = [] # Treat as no messages on error + elif result is None: # Handle case where client might return None on internal error + logger.warning(f"Received None instead of messages for {power_name}.") + messages = [] + if model_name in model_error_stats: + model_error_stats[model_name]["conversation_errors"] += 1 + else: + model_error_stats.setdefault(power_name, {}).setdefault("conversation_errors", 0) + model_error_stats[power_name]["conversation_errors"] += 1 else: - messages = result # result is the list of message dicts + messages = result # result is the list of message dicts logger.debug(f"Received {len(messages)} message(s) from {power_name}.") # Process the received messages (same logic as before) @@ -134,27 +129,27 @@ async def conduct_negotiations( # Create an official message in the Diplomacy engine # Determine recipient based on message type if message.get("message_type") == "private": - recipient = normalize_recipient_name(message.get("recipient", GLOBAL)) # Default to GLOBAL if recipient missing somehow + recipient = normalize_recipient_name(message.get("recipient", GLOBAL)) # Default to GLOBAL if recipient missing somehow if recipient not in game.powers and recipient != GLOBAL: logger.warning(f"Invalid recipient '{recipient}' in message from {power_name}. Sending globally.") - recipient = GLOBAL # Fallback to GLOBAL if recipient power is invalid - else: # Assume global if not private or type is missing + recipient = GLOBAL # Fallback to GLOBAL if recipient power is invalid + else: # Assume global if not private or type is missing recipient = GLOBAL - + diplo_message = Message( phase=game.current_short_phase, sender=power_name, - recipient=recipient, # Use determined recipient - message=message.get("content", ""), # Use .get for safety - time_sent=None, # Let the engine assign time + recipient=recipient, # Use determined recipient + message=message.get("content", ""), # Use .get for safety + time_sent=None, # Let the engine assign time ) game.add_message(diplo_message) # Also add to our custom history game_history.add_message( game.current_short_phase, power_name, - recipient, # Use determined recipient here too - message.get("content", ""), # Use .get for safety + recipient, # Use determined recipient here too + message.get("content", ""), # Use .get for safety ) journal_recipient = f"to {recipient}" if recipient != GLOBAL else "globally" agent.add_journal_entry(f"Sent message {journal_recipient} in {game.current_short_phase}: {message.get('content', '')[:100]}...") diff --git a/ai_diplomacy/planning.py b/ai_diplomacy/planning.py index 0d6b8a9..5531e5a 100644 --- a/ai_diplomacy/planning.py +++ b/ai_diplomacy/planning.py @@ -1,18 +1,17 @@ -from dotenv import load_dotenv import logging import concurrent.futures from typing import Dict -from .clients import load_model_client from .game_history import GameHistory from .agent import DiplomacyAgent logger = logging.getLogger(__name__) + async def planning_phase( - game, - agents: Dict[str, DiplomacyAgent], - game_history: GameHistory, + game, + agents: Dict[str, DiplomacyAgent], + game_history: GameHistory, model_error_stats, log_file_path: str, ): @@ -20,24 +19,18 @@ async def planning_phase( Lets each power generate a strategic plan using their DiplomacyAgent. """ logger.info(f"Starting planning phase for {game.current_short_phase}...") - active_powers = [ - p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated() - ] - eliminated_powers = [ - p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated() - ] - + active_powers = [p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()] + eliminated_powers = [p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()] + logger.info(f"Active powers for planning: {active_powers}") if eliminated_powers: logger.info(f"Eliminated powers (skipped): {eliminated_powers}") else: logger.info("No eliminated powers yet.") - + board_state = game.get_state() - with concurrent.futures.ThreadPoolExecutor( - max_workers=len(active_powers) - ) as executor: + with concurrent.futures.ThreadPoolExecutor(max_workers=len(active_powers)) as executor: futures = {} for power_name in active_powers: if power_name not in agents: @@ -45,7 +38,7 @@ async def planning_phase( continue agent = agents[power_name] client = agent.client - + future = executor.submit( client.get_plan, game, @@ -67,20 +60,18 @@ async def planning_phase( try: plan_result = future.result() logger.info(f"Received planning result from {power_name}.") - + if plan_result.startswith("Error:"): - logger.warning(f"Agent {power_name} reported an error during planning: {plan_result}") - if power_name in model_error_stats: - model_error_stats[power_name].setdefault('planning_generation_errors', 0) - model_error_stats[power_name]['planning_generation_errors'] += 1 - else: - model_error_stats.setdefault(f'{power_name}_planning_generation_errors', 0) - model_error_stats[f'{power_name}_planning_generation_errors'] += 1 + logger.warning(f"Agent {power_name} reported an error during planning: {plan_result}") + if power_name in model_error_stats: + model_error_stats[power_name].setdefault("planning_generation_errors", 0) + model_error_stats[power_name]["planning_generation_errors"] += 1 + else: + model_error_stats.setdefault(f"{power_name}_planning_generation_errors", 0) + model_error_stats[f"{power_name}_planning_generation_errors"] += 1 elif plan_result: agent.add_journal_entry(f"Generated plan for {game.current_short_phase}: {plan_result[:100]}...") - game_history.add_plan( - game.current_short_phase, power_name, plan_result - ) + game_history.add_plan(game.current_short_phase, power_name, plan_result) logger.debug(f"Added plan for {power_name} to history.") else: logger.warning(f"Agent {power_name} returned an empty plan.") @@ -88,11 +79,11 @@ async def planning_phase( except Exception as e: logger.error(f"Exception during planning result processing for {power_name}: {e}") if power_name in model_error_stats: - model_error_stats[power_name].setdefault('planning_execution_errors', 0) - model_error_stats[power_name]['planning_execution_errors'] += 1 + model_error_stats[power_name].setdefault("planning_execution_errors", 0) + model_error_stats[power_name]["planning_execution_errors"] += 1 else: - model_error_stats.setdefault(f'{power_name}_planning_execution_errors', 0) - model_error_stats[f'{power_name}_planning_execution_errors'] += 1 - + model_error_stats.setdefault(f"{power_name}_planning_execution_errors", 0) + model_error_stats[f"{power_name}_planning_execution_errors"] += 1 + logger.info("Planning phase processing complete.") - return game_history \ No newline at end of file + return game_history diff --git a/ai_diplomacy/possible_order_context.py b/ai_diplomacy/possible_order_context.py index 39c0bb9..1c0b47e 100644 --- a/ai_diplomacy/possible_order_context.py +++ b/ai_diplomacy/possible_order_context.py @@ -8,12 +8,13 @@ import logging import re # Placeholder for actual map type from diplomacy.engine.map.Map -# GameMap = Any +# GameMap = Any # Type hint for board_state dictionary from game.get_state() # BoardState = Dict[str, Any] logger = logging.getLogger(__name__) + def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]: """ Return graph[PROV]['ARMY'|'FLEET'] = list of 3-letter neighbour provinces. @@ -22,20 +23,18 @@ def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]: """ # ── collect all 3-letter province codes ─────────────────────────────── provs: Set[str] = { - loc.split("/")[0][:3].upper() # 'BUL/EC' -> 'BUL' + loc.split("/")[0][:3].upper() # 'BUL/EC' -> 'BUL' for loc in game_map.locs if len(loc.split("/")[0]) == 3 } - graph: Dict[str, Dict[str, List[str]]] = { - p: {"ARMY": [], "FLEET": []} for p in provs - } + graph: Dict[str, Dict[str, List[str]]] = {p: {"ARMY": [], "FLEET": []} for p in provs} # ── helper: list every concrete variant of a province ───────────────── def variants(code: str) -> List[str]: lst = list(game_map.loc_coasts.get(code, [])) if code not in lst: - lst.append(code) # ensure base node included + lst.append(code) # ensure base node included return lst # ── populate adjacency by brute-force queries to `abuts()` ──────────── @@ -55,48 +54,43 @@ def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]: graph[src]["ARMY"].append(dest) # FLEET — any src variant that can host a fleet is valid - if any( - game_map.abuts("F", sv, "-", dv) - for sv in src_vers - for dv in dest_vers - ): + if any(game_map.abuts("F", sv, "-", dv) for sv in src_vers for dv in dest_vers): graph[src]["FLEET"].append(dest) # ── tidy up duplicates / order --------------------------------------- for p in graph: - graph[p]["ARMY"] = sorted(set(graph[p]["ARMY"])) + graph[p]["ARMY"] = sorted(set(graph[p]["ARMY"])) graph[p]["FLEET"] = sorted(set(graph[p]["FLEET"])) return graph - def bfs_shortest_path( - graph: Dict[str, Dict[str, List[str]]], - board_state: BoardState, - game_map: GameMap, # Added game_map - start_loc_full: str, # This is a FULL location name like 'VIE' or 'STP/SC' - unit_type: str, - is_target_func: Callable[[str, BoardState], bool] # Expects SHORT name for loc -) -> Optional[List[str]]: # Returns path of SHORT names + graph: Dict[str, Dict[str, List[str]]], + board_state: BoardState, + game_map: GameMap, # Added game_map + start_loc_full: str, # This is a FULL location name like 'VIE' or 'STP/SC' + unit_type: str, + is_target_func: Callable[[str, BoardState], bool], # Expects SHORT name for loc +) -> Optional[List[str]]: # Returns path of SHORT names """Performs BFS to find the shortest path from start_loc to a target satisfying is_target_func.""" - + # Convert full start location to short province name start_loc_short = game_map.loc_name.get(start_loc_full, start_loc_full) - if '/' in start_loc_short: # If it was STP/SC, loc_name gives STP. If it was VIE, loc_name gives VIE. + if "/" in start_loc_short: # If it was STP/SC, loc_name gives STP. If it was VIE, loc_name gives VIE. start_loc_short = start_loc_short[:3] # If start_loc_full was already short (e.g. 'VIE'), get might return it as is, or its value if it was a key. # A simpler way for non-coastal full (like 'VIE') or already short: - if '/' not in start_loc_full: - start_loc_short = start_loc_full[:3] # Ensures 'VIE' -> 'VIE', 'PAR' -> 'PAR' - else: # Has '/', e.g. 'STP/SC' - start_loc_short = start_loc_full[:3] # 'STP/SC' -> 'STP' + if "/" not in start_loc_full: + start_loc_short = start_loc_full[:3] # Ensures 'VIE' -> 'VIE', 'PAR' -> 'PAR' + else: # Has '/', e.g. 'STP/SC' + start_loc_short = start_loc_full[:3] # 'STP/SC' -> 'STP' if start_loc_short not in graph: logger.warning(f"BFS: Start province {start_loc_short} (from {start_loc_full}) not in graph. Pathfinding may fail.") return None - queue: deque[Tuple[str, List[str]]] = deque([(start_loc_short, [start_loc_short])]) + queue: deque[Tuple[str, List[str]]] = deque([(start_loc_short, [start_loc_short])]) visited_nodes: Set[str] = {start_loc_short} while queue: @@ -104,14 +98,14 @@ def bfs_shortest_path( # is_target_func expects a short location name if is_target_func(current_loc_short, board_state): - return path # Path of short names + return path # Path of short names # possible_neighbors are SHORT names from the graph possible_neighbors_short = graph.get(current_loc_short, {}).get(unit_type, []) - + for next_loc_short in possible_neighbors_short: if next_loc_short not in visited_nodes: - if next_loc_short not in graph: # Defensive check for neighbors not in graph keys + if next_loc_short not in graph: # Defensive check for neighbors not in graph keys logger.warning(f"BFS: Neighbor {next_loc_short} of {current_loc_short} not in graph. Skipping.") continue visited_nodes.add(next_loc_short) @@ -119,11 +113,12 @@ def bfs_shortest_path( queue.append((next_loc_short, new_path)) return None + # --- Helper functions for context generation --- def get_unit_at_location(board_state: BoardState, location: str) -> Optional[str]: """Returns the full unit string (e.g., 'A PAR (FRA)') if a unit is at the location, else None.""" - for power, unit_list in board_state.get('units', {}).items(): - for unit_str in unit_list: # e.g., "A PAR", "F STP/SC" + for power, unit_list in board_state.get("units", {}).items(): + for unit_str in unit_list: # e.g., "A PAR", "F STP/SC" parts = unit_str.split(" ") if len(parts) == 2: unit_map_loc = parts[1] @@ -131,27 +126,29 @@ def get_unit_at_location(board_state: BoardState, location: str) -> Optional[str return f"{parts[0]} {location} ({power})" return None + def get_sc_controller(game_map: GameMap, board_state: BoardState, location: str) -> Optional[str]: """Returns the controlling power's name if the location is an SC, else None.""" # Normalize location to base province name, as SCs are tied to provinces, not specific coasts loc_province_name = game_map.loc_name.get(location, location).upper()[:3] if loc_province_name not in game_map.scs: return None - for power, sc_list in board_state.get('centers', {}).items(): + for power, sc_list in board_state.get("centers", {}).items(): if loc_province_name in sc_list: return power - return None # Unowned SC + return None # Unowned SC + def get_shortest_path_to_friendly_unit( - board_state: BoardState, + board_state: BoardState, graph: Dict[str, Dict[str, List[str]]], - game_map: GameMap, # Added game_map - power_name: str, - start_unit_loc_full: str, - start_unit_type: str + game_map: GameMap, # Added game_map + power_name: str, + start_unit_loc_full: str, + start_unit_type: str, ) -> Optional[Tuple[str, List[str]]]: """Finds the shortest path to any friendly unit of the same power.""" - + def is_target_friendly(loc_short: str, current_board_state: BoardState) -> bool: # loc_short is a short province name. Need to check all its full locations. full_locs_for_short = game_map.loc_coasts.get(loc_short, [loc_short]) @@ -162,7 +159,7 @@ def get_shortest_path_to_friendly_unit( return False path_short_names = bfs_shortest_path(graph, board_state, game_map, start_unit_loc_full, start_unit_type, is_target_friendly) - if path_short_names and len(path_short_names) > 1: # Path includes start, so > 1 means a distinct friendly unit found + if path_short_names and len(path_short_names) > 1: # Path includes start, so > 1 means a distinct friendly unit found target_loc_short = path_short_names[-1] # Find the actual friendly unit string at one of the full locations of target_loc_short friendly_unit_str = "UNKNOWN_FRIENDLY_UNIT" @@ -177,41 +174,41 @@ def get_shortest_path_to_friendly_unit( def get_nearest_enemy_units( - board_state: BoardState, + board_state: BoardState, graph: Dict[str, Dict[str, List[str]]], - game_map: GameMap, # Added game_map - power_name: str, - start_unit_loc_full: str, - start_unit_type: str, - n: int = 3 + game_map: GameMap, # Added game_map + power_name: str, + start_unit_loc_full: str, + start_unit_type: str, + n: int = 3, ) -> List[Tuple[str, List[str]]]: """Finds up to N nearest enemy units, sorted by path length.""" - enemy_paths: List[Tuple[str, List[str]]] = [] # (enemy_unit_str, path_short_names) - - all_enemy_unit_locations_full: List[Tuple[str,str]] = [] # (loc_full, unit_str_full) + enemy_paths: List[Tuple[str, List[str]]] = [] # (enemy_unit_str, path_short_names) + + all_enemy_unit_locations_full: List[Tuple[str, str]] = [] # (loc_full, unit_str_full) # board_state.get("units", {}) has format: { "POWER_NAME": ["A PAR", "F BRE"], ... } for p_name, unit_list_for_power in board_state.get("units", {}).items(): - if p_name != power_name: # If it's an enemy power - for unit_repr_from_state in unit_list_for_power: # e.g., "A PAR" or "F STP/SC" + if p_name != power_name: # If it's an enemy power + for unit_repr_from_state in unit_list_for_power: # e.g., "A PAR" or "F STP/SC" parts = unit_repr_from_state.split(" ") if len(parts) == 2: # unit_type_char = parts[0] # 'A' or 'F' - loc_full = parts[1] # 'PAR' or 'STP/SC' - + loc_full = parts[1] # 'PAR' or 'STP/SC' + # Use get_unit_at_location to get the consistent full unit string like "A PAR (POWER_NAME)" full_unit_str_with_power = get_unit_at_location(board_state, loc_full) - if full_unit_str_with_power: # Should find the unit if iteration is correct - all_enemy_unit_locations_full.append((loc_full, full_unit_str_with_power)) + if full_unit_str_with_power: # Should find the unit if iteration is correct + all_enemy_unit_locations_full.append((loc_full, full_unit_str_with_power)) for target_enemy_loc_full, enemy_unit_str in all_enemy_unit_locations_full: target_enemy_loc_short = game_map.loc_name.get(target_enemy_loc_full, target_enemy_loc_full) - if '/' in target_enemy_loc_short: + if "/" in target_enemy_loc_short: target_enemy_loc_short = target_enemy_loc_short[:3] - if '/' not in target_enemy_loc_full: + if "/" not in target_enemy_loc_full: target_enemy_loc_short = target_enemy_loc_full[:3] else: target_enemy_loc_short = target_enemy_loc_full[:3] - + def is_specific_enemy_loc(loc_short: str, current_board_state: BoardState) -> bool: # Check if loc_short corresponds to target_enemy_loc_full return loc_short == target_enemy_loc_short @@ -219,8 +216,8 @@ def get_nearest_enemy_units( path_short_names = bfs_shortest_path(graph, board_state, game_map, start_unit_loc_full, start_unit_type, is_specific_enemy_loc) if path_short_names: enemy_paths.append((enemy_unit_str, path_short_names)) - - enemy_paths.sort(key=lambda x: len(x[1])) # Sort by path length + + enemy_paths.sort(key=lambda x: len(x[1])) # Sort by path length return enemy_paths[:n] @@ -242,10 +239,10 @@ def get_nearest_uncontrolled_scs( """ results: List[Tuple[str, int, List[str]]] = [] - for sc_short in game_map.scs: # all SC province codes + for sc_short in game_map.scs: # all SC province codes controller = get_sc_controller(game_map, board_state, sc_short) if controller == power_name: - continue # already ours + continue # already ours # helper for BFS target test def is_target(loc_short: str, _state: BoardState) -> bool: @@ -260,9 +257,9 @@ def get_nearest_uncontrolled_scs( is_target, ) if not path: - continue # unreachable + continue # unreachable - distance = len(path) - 1 # moves needed + distance = len(path) - 1 # moves needed # skip distance 0 (same province) and 1 (adjacent) if distance <= 1: @@ -277,11 +274,11 @@ def get_nearest_uncontrolled_scs( def get_adjacent_territory_details( - game_map: GameMap, - board_state: BoardState, - unit_loc_full: str, # The location of the unit whose adjacencies we're checking - unit_type: str, # ARMY or FLEET of the unit at unit_loc_full - graph: Dict[str, Dict[str, List[str]]] + game_map: GameMap, + board_state: BoardState, + unit_loc_full: str, # The location of the unit whose adjacencies we're checking + unit_type: str, # ARMY or FLEET of the unit at unit_loc_full + graph: Dict[str, Dict[str, List[str]]], ) -> str: """Generates a string describing adjacent territories and units that can interact with them.""" output_lines: List[str] = [] @@ -291,37 +288,37 @@ def get_adjacent_territory_details( # For fleets, graph[unit_loc_full]['FLEET'] gives full loc names (incl coasts) # THIS COMMENT IS NOW OUTDATED. Graph uses short names for keys and values. unit_loc_short = game_map.loc_name.get(unit_loc_full, unit_loc_full) - if '/' in unit_loc_short: + if "/" in unit_loc_short: unit_loc_short = unit_loc_short[:3] - if '/' not in unit_loc_full: + if "/" not in unit_loc_full: unit_loc_short = unit_loc_full[:3] else: unit_loc_short = unit_loc_full[:3] - adjacent_locs_short_for_unit = graph.get(unit_loc_short, {}).get(unit_type, []) + adjacent_locs_short_for_unit = graph.get(unit_loc_short, {}).get(unit_type, []) - processed_adj_provinces = set() # To handle cases like STP/NC and STP/SC both being adjacent to BOT + processed_adj_provinces = set() # To handle cases like STP/NC and STP/SC both being adjacent to BOT - for adj_loc_short in adjacent_locs_short_for_unit: # adj_loc_short is already short + for adj_loc_short in adjacent_locs_short_for_unit: # adj_loc_short is already short # adj_province_short = game_map.loc_name.get(adj_loc_full, adj_loc_full).upper()[:3] # No longer needed - if adj_loc_short in processed_adj_provinces: # adj_loc_short is already short and upper implicitly by map data + if adj_loc_short in processed_adj_provinces: # adj_loc_short is already short and upper implicitly by map data continue processed_adj_provinces.add(adj_loc_short) - adj_loc_type = game_map.loc_type.get(adj_loc_short, 'UNKNOWN').upper() - if adj_loc_type == 'COAST' or adj_loc_type == 'LAND': - adj_loc_type_display = 'LAND' if adj_loc_type == 'LAND' else 'COAST' - elif adj_loc_type == 'WATER': - adj_loc_type_display = 'WATER' - else: # SHUT etc. + adj_loc_type = game_map.loc_type.get(adj_loc_short, "UNKNOWN").upper() + if adj_loc_type == "COAST" or adj_loc_type == "LAND": + adj_loc_type_display = "LAND" if adj_loc_type == "LAND" else "COAST" + elif adj_loc_type == "WATER": + adj_loc_type_display = "WATER" + else: # SHUT etc. adj_loc_type_display = adj_loc_type - + line = f" {adj_loc_short} ({adj_loc_type_display})" - + sc_controller = get_sc_controller(game_map, board_state, adj_loc_short) if sc_controller: line += f" SC Control: {sc_controller}" - + unit_in_adj_loc = get_unit_at_location(board_state, adj_loc_short) if unit_in_adj_loc: line += f" Units: {unit_in_adj_loc}" @@ -330,15 +327,14 @@ def get_adjacent_territory_details( # "Can support/move to" - Simplified: list units in *further* adjacent provinces # A true "can support/move to" would require checking possible orders of those further units. # further_adj_provinces are short names from the graph - further_adj_provinces_short = graph.get(adj_loc_short, {}).get('ARMY', []) + \ - graph.get(adj_loc_short, {}).get('FLEET', []) - + further_adj_provinces_short = graph.get(adj_loc_short, {}).get("ARMY", []) + graph.get(adj_loc_short, {}).get("FLEET", []) + supporting_units_info = [] processed_further_provinces = set() for further_adj_loc_short in further_adj_provinces_short: # further_adj_province_short = game_map.loc_name.get(further_adj_loc_full, further_adj_loc_full).upper()[:3] # No conversion needed, it's already short - if further_adj_loc_short == adj_loc_short or further_adj_loc_short == unit_loc_short: # Don't list itself or origin + if further_adj_loc_short == adj_loc_short or further_adj_loc_short == unit_loc_short: # Don't list itself or origin continue if further_adj_loc_short in processed_further_provinces: continue @@ -354,14 +350,14 @@ def get_adjacent_territory_details( temp_unit = get_unit_at_location(board_state, fv_further) if temp_unit: unit_in_further_loc = temp_unit - break # Found a unit in one of the coasts/base - + break # Found a unit in one of the coasts/base + # if not unit_in_further_loc and further_adj_loc_full != further_adj_province_short: # unit_in_further_loc = get_unit_at_location(board_state, further_adj_province_short) - + if unit_in_further_loc: supporting_units_info.append(unit_in_further_loc) - + if supporting_units_info: output_lines.append(f" => Can support/move to: {', '.join(sorted(list(set(supporting_units_info))))}") @@ -378,30 +374,30 @@ def generate_rich_order_context_xml(game: Any, power_name: str, possible_orders_ board_state: BoardState = game.get_state() game_map: GameMap = game.map graph = build_diplomacy_graph(game_map) - + final_context_lines: List[str] = [""] # Iterate through units that have orders (keys of possible_orders_for_power are unit locations) for unit_loc_full, unit_specific_possible_orders in possible_orders_for_power.items(): unit_str_full = get_unit_at_location(board_state, unit_loc_full) - if not unit_str_full: # Should not happen if unit_loc_full is from possible_orders keys - continue + if not unit_str_full: # Should not happen if unit_loc_full is from possible_orders keys + continue - unit_type_char = unit_str_full.split(" ")[0] # 'A' or 'F' - unit_type_long = "ARMY" if unit_type_char == 'A' else "FLEET" + unit_type_char = unit_str_full.split(" ")[0] # 'A' or 'F' + unit_type_long = "ARMY" if unit_type_char == "A" else "FLEET" loc_province_short = game_map.loc_name.get(unit_loc_full, unit_loc_full).upper()[:3] loc_type_short = game_map.loc_type.get(loc_province_short, "UNKNOWN").upper() - if loc_type_short == 'COAST' or loc_type_short == 'LAND': - loc_type_display = 'LAND' if loc_type_short == 'LAND' else 'COAST' + if loc_type_short == "COAST" or loc_type_short == "LAND": + loc_type_display = "LAND" if loc_type_short == "LAND" else "COAST" else: loc_type_display = loc_type_short current_unit_lines: List[str] = [] current_unit_lines.append(f' ') - + # Unit Information section - current_unit_lines.append(' ') + current_unit_lines.append(" ") sc_owner_at_loc = get_sc_controller(game_map, board_state, unit_loc_full) header_content = f"Strategic territory held by {power_name}: {unit_loc_full} ({loc_type_display})" if sc_owner_at_loc == power_name: @@ -410,55 +406,61 @@ def generate_rich_order_context_xml(game: Any, power_name: str, possible_orders_ header_content += f" (SC controlled by {sc_owner_at_loc})" current_unit_lines.append(f" {header_content}") current_unit_lines.append(f" Units present: {unit_str_full}") - current_unit_lines.append(' ') + current_unit_lines.append(" ") # Possible moves section - current_unit_lines.append(' ') + current_unit_lines.append(" ") current_unit_lines.append(" Possible moves:") for order_str in unit_specific_possible_orders: current_unit_lines.append(f" {order_str}") - current_unit_lines.append(' ') - + current_unit_lines.append(" ") + # Nearest enemy units section enemy_units_info = get_nearest_enemy_units(board_state, graph, game_map, power_name, unit_loc_full, unit_type_long, n=3) - current_unit_lines.append(' ') + current_unit_lines.append(" ") if enemy_units_info: current_unit_lines.append(" Nearest units (not ours):") for enemy_unit_str, enemy_path_short in enemy_units_info: - current_unit_lines.append(f" {enemy_unit_str}, path=[{unit_loc_full}→{('→'.join(enemy_path_short[1:])) if len(enemy_path_short) > 1 else enemy_path_short[0]}]") + current_unit_lines.append( + f" {enemy_unit_str}, path=[{unit_loc_full}→{('→'.join(enemy_path_short[1:])) if len(enemy_path_short) > 1 else enemy_path_short[0]}]" + ) else: current_unit_lines.append(" Nearest units (not ours): None found") - current_unit_lines.append(' ') + current_unit_lines.append(" ") # Nearest supply centers (not controlled by us) section uncontrolled_scs_info = get_nearest_uncontrolled_scs(game_map, board_state, graph, power_name, unit_loc_full, unit_type_long, n=3) - current_unit_lines.append(' ') + current_unit_lines.append(" ") if uncontrolled_scs_info: current_unit_lines.append(" Nearest supply centers (not controlled by us):") for sc_str, dist, sc_path_short in uncontrolled_scs_info: - current_unit_lines.append(f" {sc_str}, dist={dist}, path=[{unit_loc_full}→{('→'.join(sc_path_short[1:])) if len(sc_path_short) > 1 else sc_path_short[0]}]") + current_unit_lines.append( + f" {sc_str}, dist={dist}, path=[{unit_loc_full}→{('→'.join(sc_path_short[1:])) if len(sc_path_short) > 1 else sc_path_short[0]}]" + ) else: current_unit_lines.append(" Nearest supply centers (not controlled by us): None found") - current_unit_lines.append(' ') + current_unit_lines.append(" ") # Adjacent territories details section adj_details_str = get_adjacent_territory_details(game_map, board_state, unit_loc_full, unit_type_long, graph) - current_unit_lines.append(' ') + current_unit_lines.append(" ") if adj_details_str: current_unit_lines.append(" Adjacent territories (including units that can support/move to the adjacent territory):") # Assuming adj_details_str is already formatted with newlines and indentation for its content # We might need to indent adj_details_str if it's a single block of text # For now, let's add a standard indent to each line of adj_details_str if it contains newlines - if '\n' in adj_details_str: - indented_adj_details = "\n".join([f" {line}" for line in adj_details_str.split('\n')]) + if "\n" in adj_details_str: + indented_adj_details = "\n".join([f" {line}" for line in adj_details_str.split("\n")]) current_unit_lines.append(indented_adj_details) else: - current_unit_lines.append(f" {adj_details_str}") + current_unit_lines.append(f" {adj_details_str}") else: - current_unit_lines.append(" Adjacent territories: None relevant or all are empty/uncontested by direct threats.") # Added more descriptive else - current_unit_lines.append(' ') - - current_unit_lines.append(' ') + current_unit_lines.append( + " Adjacent territories: None relevant or all are empty/uncontested by direct threats." + ) # Added more descriptive else + current_unit_lines.append(" ") + + current_unit_lines.append(" ") final_context_lines.extend(current_unit_lines) final_context_lines.append("") @@ -469,16 +471,16 @@ def generate_rich_order_context_xml(game: Any, power_name: str, possible_orders_ # Regex and tiny helpers # --------------------------------------------------------------------------- -import re from typing import Tuple, List, Dict, Optional, Any # ── order-syntax matchers ───────────────────────────────────────────────── _SIMPLE_MOVE_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? - [A-Z]{3}(?:/[A-Z]{2})?$") -_HOLD_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? H$") # NEW -_RETREAT_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? R [A-Z]{3}(?:/[A-Z]{2})?$") -_ADJUST_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? [BD]$") # build / disband +_HOLD_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? H$") # NEW +_RETREAT_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? R [A-Z]{3}(?:/[A-Z]{2})?$") +_ADJUST_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? [BD]$") # build / disband -def _is_hold_order(order: str) -> bool: # NEW + +def _is_hold_order(order: str) -> bool: # NEW return bool(_HOLD_RE.match(order.strip())) @@ -486,15 +488,19 @@ def _norm_power(name: str) -> str: """Trim & uppercase for reliable comparisons.""" return name.strip().upper() + def _is_simple_move(order: str) -> bool: return bool(_SIMPLE_MOVE_RE.match(order.strip())) + def _is_retreat_order(order: str) -> bool: return bool(_RETREAT_RE.match(order.strip())) + def _is_adjust_order(order: str) -> bool: return bool(_ADJUST_RE.match(order.strip())) + def _split_move(order: str) -> Tuple[str, str]: """Return ('A BUD', 'TRI') from 'A BUD - TRI' (validated move only).""" unit_part, dest = order.split(" - ") @@ -505,6 +511,7 @@ def _split_move(order: str) -> Tuple[str, str]: # Gather *all* friendly support orders for a given move # --------------------------------------------------------------------------- + def _all_support_examples( mover: str, dest: str, @@ -526,6 +533,7 @@ def _all_support_examples( return supports + def _all_support_hold_examples( holder: str, all_orders: Dict[str, List[str]], @@ -538,7 +546,7 @@ def _all_support_hold_examples( supports: List[str] = [] for loc, orders in all_orders.items(): - if holder.endswith(loc): # skip the holding unit + if holder.endswith(loc): # skip the holding unit continue for o in orders: if o.strip().endswith(target): @@ -550,6 +558,7 @@ def _all_support_hold_examples( # Province-type resolver (handles short codes, coasts, seas) # --------------------------------------------------------------------------- + def _province_type_display(game_map, prov_short: str) -> str: """ Return 'LAND', 'COAST', or 'WATER' for the 3-letter province code. @@ -575,7 +584,7 @@ def _dest_occupancy_desc( board_state, our_power: str, ) -> str: - """ '(occupied by X)', '(occupied by X — you!)', or '(unoccupied)' """ + """'(occupied by X)', '(occupied by X — you!)', or '(unoccupied)'""" occupant: Optional[str] = None for full in game_map.loc_coasts.get(dest_short, [dest_short]): u = get_unit_at_location(board_state, full) @@ -593,6 +602,7 @@ def _dest_occupancy_desc( # Adjacent-territory lines (used by movement-phase builder) # --------------------------------------------------------------------------- + def _adjacent_territory_lines( graph, game_map, @@ -633,13 +643,9 @@ def _adjacent_territory_lines( pwr = unit_here.split(" ")[-1].strip("()") if pwr == our_power: friend_descr = unit_here.split(" (")[0] - lines.append( - f"{indent2}Support hold: {mover_descr} S {friend_descr}" - ) + lines.append(f"{indent2}Support hold: {mover_descr} S {friend_descr}") else: - lines.append( - f"{indent2}-> {unit_here} can support or contest {mover_descr}’s moves and vice-versa" - ) + lines.append(f"{indent2}-> {unit_here} can support or contest {mover_descr}’s moves and vice-versa") return lines @@ -648,6 +654,7 @@ def _adjacent_territory_lines( # Movement-phase generator (UNCHANGED LOGIC) # --------------------------------------------------------------------------- + def _generate_rich_order_context_movement( game: Any, power_name: str, @@ -672,18 +679,14 @@ def _generate_rich_order_context_movement( if _norm_power(unit_power) != me: continue # Skip units that aren’t ours - mover_descr, _ = _split_move( - f"{unit_full_str.split(' ')[0]} {unit_loc_full} - {unit_loc_full}" - ) + mover_descr, _ = _split_move(f"{unit_full_str.split(' ')[0]} {unit_loc_full} - {unit_loc_full}") prov_short = game_map.loc_name.get(unit_loc_full, unit_loc_full)[:3] prov_type_disp = _province_type_display(game_map, prov_short) sc_tag = " (SC)" if prov_short in game_map.scs else "" owner = get_sc_controller(game_map, board_state, unit_loc_full) or "None" - owner_line = ( - f"Held by {owner} (You)" if owner == power_name else f"Held by {owner}" - ) + owner_line = f"Held by {owner} (You)" if owner == power_name else f"Held by {owner}" ind = " " block: List[str] = [f""] @@ -693,18 +696,16 @@ def _generate_rich_order_context_movement( # ----- Adjacent territories ----- block.append("# Adjacent territories:") - block.extend( - _adjacent_territory_lines( - graph, game_map, board_state, - unit_loc_full, mover_descr, power_name - ) - ) + block.extend(_adjacent_territory_lines(graph, game_map, board_state, unit_loc_full, mover_descr, power_name)) # ----- Nearest enemy units ----- block.append("# Nearest units (not ours):") enemies = get_nearest_enemy_units( - board_state, graph, game_map, - power_name, unit_loc_full, + board_state, + graph, + game_map, + power_name, + unit_loc_full, "ARMY" if mover_descr.startswith("A") else "FLEET", n=3, ) @@ -715,8 +716,11 @@ def _generate_rich_order_context_movement( # ----- Nearest uncontrolled SCs ----- block.append("# Nearest supply centers (not controlled by us):") scs = get_nearest_uncontrolled_scs( - game_map, board_state, graph, - power_name, unit_loc_full, + game_map, + board_state, + graph, + power_name, + unit_loc_full, "ARMY" if mover_descr.startswith("A") else "FLEET", n=3, ) @@ -729,7 +733,7 @@ def _generate_rich_order_context_movement( block.append(f"# Possible {mover_descr} unit movements & supports:") simple_moves = [o for o in orders if _is_simple_move(o)] - hold_orders = [o for o in orders if _is_hold_order(o)] # NEW + hold_orders = [o for o in orders if _is_hold_order(o)] # NEW if not simple_moves and not hold_orders: block.append(f"{ind}None") @@ -737,22 +741,19 @@ def _generate_rich_order_context_movement( # ---- Moves (same behaviour as before) ---- for mv in simple_moves: mover, dest = _split_move(mv) - occ = _dest_occupancy_desc( - dest.split("/")[0][:3], game_map, board_state, power_name - ) + occ = _dest_occupancy_desc(dest.split("/")[0][:3], game_map, board_state, power_name) block.append(f"{ind}{mv} {occ}") for s in _all_support_examples(mover, dest, possible_orders_for_power): - block.append(f"{ind*2}Available Support: {s}") + block.append(f"{ind * 2}Available Support: {s}") # ---- Holds (new) ---- for hd in hold_orders: - holder = hd.split(" H")[0] # e.g., 'F DEN' + holder = hd.split(" H")[0] # e.g., 'F DEN' block.append(f"{ind}{hd}") for s in _all_support_hold_examples(holder, possible_orders_for_power): - block.append(f"{ind*2}Available Support: {s}") - + block.append(f"{ind * 2}Available Support: {s}") block.append(f"") blocks.append("\n".join(block)) @@ -764,6 +765,7 @@ def _generate_rich_order_context_movement( # Retreat-phase builder – echo orders verbatim, no tags # --------------------------------------------------------------------------- + def _generate_rich_order_context_retreat( game: Any, power_name: str, @@ -784,11 +786,11 @@ def _generate_rich_order_context_retreat( return "\n".join(lines) if lines else "(No dislodged units)" - # --------------------------------------------------------------------------- # Adjustment-phase builder – summary line + orders, no WAIVEs, no tags # --------------------------------------------------------------------------- + def _generate_rich_order_context_adjustment( game: Any, power_name: str, @@ -800,12 +802,12 @@ def _generate_rich_order_context_adjustment( * No wrapper tags. """ board_state = game.get_state() - sc_owned = len(board_state.get("centers", {}).get(power_name, [])) - units_num = len(board_state.get("units", {}).get(power_name, [])) - delta = sc_owned - units_num # +ve ⇒ builds, -ve ⇒ disbands + sc_owned = len(board_state.get("centers", {}).get(power_name, [])) + units_num = len(board_state.get("units", {}).get(power_name, [])) + delta = sc_owned - units_num # +ve ⇒ builds, -ve ⇒ disbands # ----- summary line ---------------------------------------------------- - if delta > 0: + if delta > 0: summary = f"Builds available: {delta}" elif delta < 0: summary = f"Disbands required: {-delta}" @@ -824,13 +826,11 @@ def _generate_rich_order_context_adjustment( return "\n".join(lines) if len(lines) > 1 else summary - - - # --------------------------------------------------------------------------- # Phase-dispatch wrapper (public entry point) # --------------------------------------------------------------------------- + def generate_rich_order_context( game: Any, power_name: str, @@ -846,22 +846,14 @@ def generate_rich_order_context( phase_type = game.current_short_phase[-1] - if phase_type == "M": # Movement - return _generate_rich_order_context_movement( - game, power_name, possible_orders_for_power - ) + if phase_type == "M": # Movement + return _generate_rich_order_context_movement(game, power_name, possible_orders_for_power) - if phase_type == "R": # Retreat - return _generate_rich_order_context_retreat( - game, power_name, possible_orders_for_power - ) + if phase_type == "R": # Retreat + return _generate_rich_order_context_retreat(game, power_name, possible_orders_for_power) - if phase_type == "A": # Adjustment (build / disband) - return _generate_rich_order_context_adjustment( - game, power_name, possible_orders_for_power - ) + if phase_type == "A": # Adjustment (build / disband) + return _generate_rich_order_context_adjustment(game, power_name, possible_orders_for_power) # Fallback – treat unknown formats as movement - return _generate_rich_order_context_movement( - game, power_name, possible_orders_for_power - ) + return _generate_rich_order_context_movement(game, power_name, possible_orders_for_power) diff --git a/ai_diplomacy/prompt_constructor.py b/ai_diplomacy/prompt_constructor.py index d77cffb..07d4f25 100644 --- a/ai_diplomacy/prompt_constructor.py +++ b/ai_diplomacy/prompt_constructor.py @@ -11,7 +11,6 @@ from .possible_order_context import ( generate_rich_order_context, generate_rich_order_context_xml, ) -import os from .game_history import GameHistory # Assuming GameHistory is correctly importable # placeholder for diplomacy.Game to avoid circular or direct dependency if not needed for typehinting only @@ -230,4 +229,3 @@ def construct_order_generation_prompt( logger.debug(f"Final order generation prompt preview for {power_name}: {final_prompt[:500]}...") return final_prompt - diff --git a/ai_diplomacy/utils.py b/ai_diplomacy/utils.py index b1808b9..e47643f 100644 --- a/ai_diplomacy/utils.py +++ b/ai_diplomacy/utils.py @@ -16,7 +16,7 @@ from ..config import config if TYPE_CHECKING: from .clients import BaseModelClient # If DiplomacyAgent is used for type hinting for an 'agent' parameter: - # from .agent import DiplomacyAgent + # from .agent import DiplomacyAgent logger = logging.getLogger("utils") logger.setLevel(logging.INFO) @@ -32,12 +32,12 @@ def atomic_write_json(data: dict, filepath: str): dir_name = os.path.dirname(filepath) if dir_name: os.makedirs(dir_name, exist_ok=True) - + # Write to a temporary file in the same directory temp_filepath = f"{filepath}.tmp.{os.getpid()}" - with open(temp_filepath, 'w', encoding='utf-8') as f: + with open(temp_filepath, "w", encoding="utf-8") as f: json.dump(data, f, indent=4) - + # Atomically rename the temporary file to the final destination os.rename(temp_filepath, filepath) except Exception as e: @@ -55,15 +55,15 @@ def assign_models_to_powers() -> Dict[str, str]: Example usage: define which model each power uses. Return a dict: { power_name: model_id, ... } POWERS = ['AUSTRIA', 'ENGLAND', 'FRANCE', 'GERMANY', 'ITALY', 'RUSSIA', 'TURKEY'] - Models supported: o3-mini, o4-mini, o3, gpt-4o, gpt-4o-mini, - claude-opus-4-20250514, claude-sonnet-4-20250514, claude-3-5-haiku-20241022, claude-3-5-sonnet-20241022, claude-3-7-sonnet-20250219 - gemini-2.0-flash, gemini-2.5-flash-preview-04-17, gemini-2.5-pro-preview-03-25, + Models supported: o3-mini, o4-mini, o3, gpt-4o, gpt-4o-mini, + claude-opus-4-20250514, claude-sonnet-4-20250514, claude-3-5-haiku-20241022, claude-3-5-sonnet-20241022, claude-3-7-sonnet-20250219 + gemini-2.0-flash, gemini-2.5-flash-preview-04-17, gemini-2.5-pro-preview-03-25, deepseek-chat, deepseek-reasoner openrouter-meta-llama/llama-3.3-70b-instruct, openrouter-qwen/qwen3-235b-a22b, openrouter-microsoft/phi-4-reasoning-plus:free, - openrouter-deepseek/deepseek-prover-v2:free, openrouter-meta-llama/llama-4-maverick:free, openrouter-nvidia/llama-3.3-nemotron-super-49b-v1:free, + openrouter-deepseek/deepseek-prover-v2:free, openrouter-meta-llama/llama-4-maverick:free, openrouter-nvidia/llama-3.3-nemotron-super-49b-v1:free, openrouter-google/gemma-3-12b-it:free, openrouter-google/gemini-2.5-flash-preview-05-20 """ - + # POWER MODELS """ return { @@ -77,13 +77,13 @@ def assign_models_to_powers() -> Dict[str, str]: } """ # TEST MODELS - + return { "AUSTRIA": "openrouter-mistralai/mistral-small-3.2-24b-instruct", "ENGLAND": "openrouter-mistralai/mistral-small-3.2-24b-instruct", "FRANCE": "openrouter-mistralai/mistral-small-3.2-24b-instruct", "GERMANY": "openrouter-mistralai/mistral-small-3.2-24b-instruct", - "ITALY": "openrouter-mistralai/mistral-small-3.2-24b-instruct", + "ITALY": "openrouter-mistralai/mistral-small-3.2-24b-instruct", "RUSSIA": "openrouter-mistralai/mistral-small-3.2-24b-instruct", "TURKEY": "openrouter-mistralai/mistral-small-3.2-24b-instruct", } @@ -92,30 +92,26 @@ def assign_models_to_powers() -> Dict[str, str]: def get_special_models() -> Dict[str, str]: """ Define models for special purposes like phase summaries and formatting. - + These can be overridden via environment variables: - AI_DIPLOMACY_NARRATIVE_MODEL: Model for phase summaries (default: "o3") - AI_DIPLOMACY_FORMATTER_MODEL: Model for JSON formatting (default: "google/gemini-2.5-flash-lite-preview-06-17") - + Returns: dict: { "phase_summary": model for generating narrative phase summaries, "formatter": model for formatting natural language to JSON } - + Examples: # Use Claude for phase summaries export AI_DIPLOMACY_NARRATIVE_MODEL="claude-3-5-sonnet-20241022" - + # Use a different Gemini model for formatting export AI_DIPLOMACY_FORMATTER_MODEL="gemini-2.0-flash" """ - return { - "phase_summary": config.AI_DIPLOMACY_NARRATIVE_MODEL, - "formatter": config.AI_DIPLOMACY_FORMATTER_MODEL - } - - + return {"phase_summary": config.AI_DIPLOMACY_NARRATIVE_MODEL, "formatter": config.AI_DIPLOMACY_FORMATTER_MODEL} + def gather_possible_orders(game: Game, power_name: str) -> Dict[str, List[str]]: """ @@ -132,7 +128,7 @@ def gather_possible_orders(game: Game, power_name: str) -> Dict[str, List[str]]: async def get_valid_orders( game: Game, - client, # BaseModelClient instance + client, # BaseModelClient instance board_state, power_name: str, possible_orders: Dict[str, List[str]], @@ -170,20 +166,24 @@ async def get_valid_orders( ) invalid_info: list[str] = [] - valid: list[str] = [] + valid: list[str] = [] invalid: list[str] = [] # ── 2. Type check ────────────────────────────────────────── if not isinstance(raw_orders, list): - logger.warning("[%s] Orders received from LLM are not a list: %s. Using fallback.", - power_name, raw_orders) + logger.warning("[%s] Orders received from LLM are not a list: %s. Using fallback.", power_name, raw_orders) model_error_stats[client.model_name]["order_decoding_errors"] += 1 return {"valid": client.fallback_orders(possible_orders), "invalid": []} # ── 3. Round-trip validation with engine ─────────────────── CODE_TO_ENGINE = { - "AUT": "AUSTRIA", "ENG": "ENGLAND", "FRA": "FRANCE", - "GER": "GERMANY", "ITA": "ITALY", "RUS": "RUSSIA", "TUR": "TURKEY", + "AUT": "AUSTRIA", + "ENG": "ENGLAND", + "FRA": "FRANCE", + "GER": "GERMANY", + "ITA": "ITALY", + "RUS": "RUSSIA", + "TUR": "TURKEY", } engine_power = power_name if power_name in game.powers else CODE_TO_ENGINE[power_name] @@ -202,16 +202,16 @@ async def get_valid_orders( game.set_orders(engine_power, [upper]) normed = game.get_orders(engine_power) - if normed: # accepted + if normed: # accepted valid.append(normed[0]) - else: # rejected + else: # rejected invalid.append(upper) invalid_info.append(f"Order '{move}' is invalid for {power_name}") game.clear_orders(engine_power) # clean slate for main engine flow # ── 4. Legacy logging & stats updates ────────────────────── - if invalid_info: # at least one bad move + if invalid_info: # at least one bad move logger.debug("[%s] Invalid orders: %s", power_name, ", ".join(invalid_info)) model_error_stats[client.model_name]["order_decoding_errors"] += 1 logger.debug("[%s] Some orders invalid, using fallback.", power_name) @@ -226,8 +226,6 @@ async def get_valid_orders( return {"valid": valid, "invalid": invalid} - - def normalize_and_compare_orders( issued_orders: Dict[str, List[str]], accepted_orders_dict: Dict[str, List[str]], @@ -325,18 +323,18 @@ def load_prompt(filename: str, prompts_dir: Optional[str] = None) -> str: 3. Elif *prompts_dir* is provided → join prompts_dir + filename. 4. Otherwise → join the package’s default prompts dir. """ - if os.path.isabs(filename): # rule 1 + if os.path.isabs(filename): # rule 1 prompt_path = filename - elif os.path.dirname(filename): # rule 2 (has slash) + elif os.path.dirname(filename): # rule 2 (has slash) # If it's a relative path with directory, join with prompts_dir if provided if prompts_dir: prompt_path = os.path.join(prompts_dir, filename) else: default_dir = os.path.join(os.path.dirname(__file__), "prompts") prompt_path = os.path.join(default_dir, filename) - elif prompts_dir: # rule 3 + elif prompts_dir: # rule 3 prompt_path = os.path.join(prompts_dir, filename) - else: # rule 4 + else: # rule 4 default_dir = os.path.join(os.path.dirname(__file__), "prompts") prompt_path = os.path.join(default_dir, filename) @@ -348,16 +346,14 @@ def load_prompt(filename: str, prompts_dir: Optional[str] = None) -> str: return "" - - # == New LLM Response Logging Function == def log_llm_response( log_file_path: str, model_name: str, - power_name: Optional[str], # Optional for non-power-specific calls like summary + power_name: Optional[str], # Optional for non-power-specific calls like summary phase: str, response_type: str, - raw_input_prompt: str, # Added new parameter for the raw input + raw_input_prompt: str, # Added new parameter for the raw input raw_response: str, success: str, # Changed from bool to str ): @@ -365,8 +361,8 @@ def log_llm_response( try: # Ensure the directory exists log_dir = os.path.dirname(log_file_path) - if log_dir: # Ensure log_dir is not empty (e.g., if path is just a filename) - os.makedirs(log_dir, exist_ok=True) + if log_dir: # Ensure log_dir is not empty (e.g., if path is just a filename) + os.makedirs(log_dir, exist_ok=True) # Check if file exists and has content to determine if we need headers file_exists = os.path.isfile(log_file_path) and os.path.getsize(log_file_path) > 0 @@ -374,34 +370,38 @@ def log_llm_response( with open(log_file_path, "a", newline="", encoding="utf-8") as csvfile: # Added "raw_input" to fieldnames fieldnames = ["model", "power", "phase", "response_type", "raw_input", "raw_response", "success"] - writer = csv.DictWriter(csvfile, fieldnames=fieldnames, - quoting=csv.QUOTE_ALL, # Quote all fields to handle commas and newlines - escapechar='\\') # Use backslash for escaping + writer = csv.DictWriter( + csvfile, + fieldnames=fieldnames, + quoting=csv.QUOTE_ALL, # Quote all fields to handle commas and newlines + escapechar="\\", + ) # Use backslash for escaping if not file_exists: writer.writeheader() # Write header only if file is new - writer.writerow({ - "model": model_name, - "power": power_name if power_name else "game", # Use 'game' if no specific power - "phase": phase, - "response_type": response_type, - "raw_input": raw_input_prompt, # Added raw_input to the row - "raw_response": raw_response, - "success": success, - }) + writer.writerow( + { + "model": model_name, + "power": power_name if power_name else "game", # Use 'game' if no specific power + "phase": phase, + "response_type": response_type, + "raw_input": raw_input_prompt, # Added raw_input to the row + "raw_response": raw_response, + "success": success, + } + ) except Exception as e: logger.error(f"Failed to log LLM response to {log_file_path}: {e}", exc_info=True) - async def run_llm_and_log( client: "BaseModelClient", prompt: str, - log_file_path: str, # Kept for context, but not used for logging here + log_file_path: str, # Kept for context, but not used for logging here power_name: Optional[str], # Kept for context, but not used for logging here - phase: str, # Kept for context, but not used for logging here - response_type: str, # Kept for context, but not used for logging here + phase: str, # Kept for context, but not used for logging here + response_type: str, # Kept for context, but not used for logging here temperature: float = 0.0, *, attempts: int = 5, @@ -428,43 +428,34 @@ async def run_llm_and_log( except Exception as e: if attempt == attempts - 1: logger.error( - f"API Error during LLM call for {client.model_name}/{power_name}/{response_type} " - f"in phase {phase}: {e}", + f"API Error during LLM call for {client.model_name}/{power_name}/{response_type} in phase {phase}: {e}", exc_info=True, ) # Back-off before the next attempt (unless this was the last) - delay = backoff_base * (backoff_factor ** attempt) + random.uniform(0, jitter) + delay = backoff_base * (backoff_factor**attempt) + random.uniform(0, jitter) await asyncio.sleep(delay) return raw_response - -# This generates a few lines of random alphanum chars to inject into the -# system prompt. This lets us use temp=0 while still getting variation +# This generates a few lines of random alphanum chars to inject into the +# system prompt. This lets us use temp=0 while still getting variation # between trials. -# Temp=0 is important for better performance on deciding moves, and to +# Temp=0 is important for better performance on deciding moves, and to # ensure valid json outputs. def generate_random_seed(n_lines: int = 5, n_chars_per_line: int = 80): - # Generate x lines of y random alphanumeric characters - seed_lines = [ - ''.join(random.choices(string.ascii_letters + string.digits, k=n_chars_per_line)) - for _ in range(n_lines) - ] - random_seed_block = ( - "\n" + - "\n".join(seed_lines) + - "\n" - ) - return random_seed_block + # Generate x lines of y random alphanumeric characters + seed_lines = ["".join(random.choices(string.ascii_letters + string.digits, k=n_chars_per_line)) for _ in range(n_lines)] + random_seed_block = "\n" + "\n".join(seed_lines) + "\n" + return random_seed_block def get_prompt_path(prompt_name: str) -> str: """Get the appropriate prompt path based on USE_UNFORMATTED_PROMPTS setting. - + Args: prompt_name: Base name of the prompt file (e.g., "conversation_instructions.txt") - + Returns: str: Either "unformatted/{prompt_name}" or just "{prompt_name}" """ @@ -472,36 +463,37 @@ def get_prompt_path(prompt_name: str) -> str: return f"unformatted/{prompt_name}" else: return prompt_name - -def normalize_recipient_name(recipient: str) -> str: - """Normalize recipient names to handle LLM typos and abbreviations.""" - if not recipient: - return recipient - - recipient = recipient.upper().strip() - - # Handle common LLM typos and abbreviations found in data - name_mapping = { - 'EGMANY': 'GERMANY', - 'GERMAN': 'GERMANY', - 'UK': 'ENGLAND', - 'BRIT': 'ENGLAND', - 'ENGLAND': 'ENGLAND', # Keep as-is - 'FRANCE': 'FRANCE', # Keep as-is - 'GERMANY': 'GERMANY', # Keep as-is - 'ITALY': 'ITALY', # Keep as-is - 'AUSTRIA': 'AUSTRIA', # Keep as-is - 'RUSSIA': 'RUSSIA', # Keep as-is - 'TURKEY': 'TURKEY', # Keep as-is - 'Germany': 'GERMANY', - 'England': 'ENGLAND', - 'France': 'FRANCE', - 'Italy': 'ITALY', - 'Russia': 'RUSSIA', - 'Austria': 'AUSTRIA', - 'Turkey': 'TURKEY', - } - - normalized = name_mapping.get(recipient, recipient) - return normalized + +def normalize_recipient_name(recipient: str) -> str: + """Normalize recipient names to handle LLM typos and abbreviations.""" + if not recipient: + return recipient + + recipient = recipient.upper().strip() + + # Handle common LLM typos and abbreviations found in data + name_mapping = { + "EGMANY": "GERMANY", + "GERMAN": "GERMANY", + "UK": "ENGLAND", + "BRIT": "ENGLAND", + "ENGLAND": "ENGLAND", # Keep as-is + "FRANCE": "FRANCE", # Keep as-is + "GERMANY": "GERMANY", # Keep as-is + "ITALY": "ITALY", # Keep as-is + "AUSTRIA": "AUSTRIA", # Keep as-is + "RUSSIA": "RUSSIA", # Keep as-is + "TURKEY": "TURKEY", # Keep as-is + "Germany": "GERMANY", + "England": "ENGLAND", + "France": "FRANCE", + "Italy": "ITALY", + "Russia": "RUSSIA", + "Austria": "AUSTRIA", + "Turkey": "TURKEY", + } + + normalized = name_mapping.get(recipient, recipient) + + return normalized