diff --git a/.vscode/launch.json b/.vscode/launch.json index ba7f205..b0d66bb 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -105,7 +105,7 @@ "--game-id", "${input:gameId}", "--log-level", - "INFO" + "DEBUG" ], "console": "integratedTerminal", "justMyCode": true, @@ -138,6 +138,7 @@ "type": "pickString", "description": "Select the AI model to use", "options": [ + "gemini-2.5-flash-lite-preview-06-17", "gpt-3.5-turbo", "gpt-4", "claude-3-haiku", diff --git a/ai_diplomacy/agent.py b/ai_diplomacy/agent.py index 2ae7313..d24a472 100644 --- a/ai_diplomacy/agent.py +++ b/ai_diplomacy/agent.py @@ -5,6 +5,7 @@ import re import json_repair import json5 +from ai_diplomacy.game_history import GameHistory from diplomacy.engine.game import Game # More forgiving JSON parser # Assuming BaseModelClient is importable from clients.py in the same directory @@ -1150,6 +1151,9 @@ class DiplomacyAgent: game_history: "GameHistory", log_file_path: str | Path, ): + assert GameHistory is not None, ( + "GameHistory object must be initialized before calling analyze_phase_and_update_state" + ) """Analyzes the outcome of the last phase and updates goals/relationships using the LLM.""" # Use self.power_name internally power_name = self.power_name @@ -1169,7 +1173,7 @@ class DiplomacyAgent: return # Get previous phase safely from history - if not game_history or not game_history.phases: + if not game_history.phases: logger.warning( f"[{power_name}] No game history available to analyze for {game.current_short_phase}. Skipping state update." ) @@ -1254,9 +1258,8 @@ class DiplomacyAgent: log_entry_success = "FALSE" # Default update_data = None # Initialize - if ( - response is not None and response.strip() - ): # Check if response is not None and not just whitespace + # Check if response is not None and not just whitespace + if response is not None and response.strip(): try: update_data = self._extract_json_from_text(response) logger.debug( diff --git a/ai_diplomacy/clients.py b/ai_diplomacy/clients.py index a0fb6dc..e79af5d 100644 --- a/ai_diplomacy/clients.py +++ b/ai_diplomacy/clients.py @@ -15,7 +15,6 @@ from anthropic import AsyncAnthropic import google.generativeai as genai -from diplomacy.engine.message import GLOBAL from .game_history import GameHistory from .utils import ( load_prompt, @@ -25,7 +24,6 @@ from .utils import ( # Import DiplomacyAgent for type hinting if needed, but avoid circular import if possible # from .agent import DiplomacyAgent -from .possible_order_context import generate_rich_order_context from .prompt_constructor import ( construct_order_generation_prompt, build_context_prompt, @@ -113,7 +111,6 @@ class BaseModelClient: raw_response = await run_llm_and_log( client=self, prompt=prompt, - log_file_path=log_file_path, power_name=power_name, phase=phase, response_type="order", # Context for run_llm_and_log's own error logging @@ -553,7 +550,6 @@ class BaseModelClient: raw_response = await run_llm_and_log( client=self, prompt=prompt, - log_file_path=log_file_path, power_name=power_name, phase=game_phase, # Use game_phase for logging response_type="plan_reply", # Changed from 'plan' to avoid confusion @@ -604,7 +600,6 @@ class BaseModelClient: raw_response = await run_llm_and_log( client=self, prompt=raw_input_prompt, - log_file_path=log_file_path, power_name=power_name, phase=game_phase, response_type="negotiation", # For run_llm_and_log's internal context @@ -829,7 +824,6 @@ class BaseModelClient: raw_plan_response = await run_llm_and_log( client=self, # Pass self (the client instance) prompt=full_prompt, - log_file_path=log_file_path, power_name=power_name, phase=game.current_short_phase, response_type="plan_generation", # More specific type for run_llm_and_log context @@ -1328,4 +1322,3 @@ def get_visible_messages_for_power(conversation_messages, power_name): ): visible.append(msg) return visible # already in chronological order if appended that way - diff --git a/ai_diplomacy/diary_logic.py b/ai_diplomacy/diary_logic.py index 407739d..a28a702 100644 --- a/ai_diplomacy/diary_logic.py +++ b/ai_diplomacy/diary_logic.py @@ -3,7 +3,7 @@ import logging import re from typing import TYPE_CHECKING, Optional -from .utils import run_llm_and_log, log_llm_response, load_prompt +from .utils import run_llm_and_log, log_llm_response if TYPE_CHECKING: from diplomacy import Game @@ -12,6 +12,25 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> str | None: + """A local copy of the helper from agent.py to avoid circular imports.""" + import os + + try: + if prompts_dir: + filepath = os.path.join(prompts_dir, filename) + else: + current_dir = os.path.dirname(os.path.abspath(__file__)) + default_prompts_dir = os.path.join(current_dir, "prompts") + filepath = os.path.join(default_prompts_dir, filename) + + with open(filepath, "r", encoding="utf-8") as f: + return f.read() + except Exception as e: + logger.error(f"Error loading prompt file {filepath}: {e}") + return None + + async def run_diary_consolidation( agent: "DiplomacyAgent", game: "Game", @@ -23,19 +42,32 @@ async def run_diary_consolidation( Consolidate older diary entries while keeping recent ones. This is the logic moved from the DiplomacyAgent class. """ - logger.info(f"[{agent.power_name}] CONSOLIDATION START — {len(agent.full_private_diary)} total full entries") + logger.info( + f"[{agent.power_name}] CONSOLIDATION START — " + f"{len(agent.full_private_diary)} total full entries" + ) - full_entries = [e for e in agent.full_private_diary if not e.startswith("[CONSOLIDATED HISTORY]")] + full_entries = [ + e + for e in agent.full_private_diary + if not e.startswith("[CONSOLIDATED HISTORY]") + ] if len(full_entries) <= entries_to_keep_unsummarized: agent.private_diary = list(agent.full_private_diary) - logger.info(f"[{agent.power_name}] ≤ {entries_to_keep_unsummarized} full entries — skipping consolidation") + logger.info( + f"[{agent.power_name}] ≤ {entries_to_keep_unsummarized} full entries — " + "skipping consolidation" + ) return boundary_entry = full_entries[-entries_to_keep_unsummarized] match = re.search(r"\[[SFWRAB]\s*(\d{4})", boundary_entry) if not match: - logger.error(f"[{agent.power_name}] Could not parse year from boundary entry; aborting consolidation") + logger.error( + f"[{agent.power_name}] Could not parse year from boundary entry; " + "aborting consolidation" + ) agent.private_diary = list(agent.full_private_diary) return @@ -46,19 +78,37 @@ async def run_diary_consolidation( m = re.search(r"\[[SFWRAB]\s*(\d{4})", entry) return int(m.group(1)) if m else None - entries_to_summarize = [e for e in full_entries if (_entry_year(e) is not None and _entry_year(e) < cutoff_year)] - entries_to_keep = [e for e in full_entries if (_entry_year(e) is None or _entry_year(e) >= cutoff_year)] + entries_to_summarize = [ + e + for e in full_entries + if (_entry_year(e) is not None and _entry_year(e) < cutoff_year) + ] + entries_to_keep = [ + e + for e in full_entries + if (_entry_year(e) is None or _entry_year(e) >= cutoff_year) + ] - logger.info(f"[{agent.power_name}] Summarising {len(entries_to_summarize)} entries; keeping {len(entries_to_keep)} recent entries verbatim") + logger.info( + f"[{agent.power_name}] Summarising {len(entries_to_summarize)} entries; " + f"keeping {len(entries_to_keep)} recent entries verbatim" + ) if not entries_to_summarize: agent.private_diary = list(agent.full_private_diary) - logger.warning(f"[{agent.power_name}] No eligible entries to summarise; context diary left unchanged") + logger.warning( + f"[{agent.power_name}] No eligible entries to summarise; " + "context diary left unchanged" + ) return - prompt_template = load_prompt("diary_consolidation_prompt.txt", prompts_dir=prompts_dir) + prompt_template = _load_prompt_file( + "diary_consolidation_prompt.txt", prompts_dir=prompts_dir + ) if not prompt_template: - logger.error(f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting") + logger.error( + f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting" + ) return prompt = prompt_template.format( @@ -87,14 +137,23 @@ async def run_diary_consolidation( new_summary_entry = f"[CONSOLIDATED HISTORY] {consolidated_text}" agent.private_diary = [new_summary_entry] + entries_to_keep success_flag = "TRUE" - logger.info(f"[{agent.power_name}] Consolidation complete — {len(agent.private_diary)} context entries now") + logger.info( + f"[{agent.power_name}] Consolidation complete — " + f"{len(agent.private_diary)} context entries now" + ) except Exception as exc: - logger.error(f"[{agent.power_name}] Diary consolidation failed: {exc}", exc_info=True) + logger.error( + f"[{agent.power_name}] Diary consolidation failed: {exc}", exc_info=True + ) finally: log_llm_response( log_file_path=log_file_path, - model_name=(consolidation_client.model_name if consolidation_client is not None else agent.client.model_name), + model_name=( + consolidation_client.model_name + if consolidation_client is not None + else agent.client.model_name + ), power_name=agent.power_name, phase=game.current_short_phase, response_type="diary_consolidation", @@ -102,3 +161,4 @@ async def run_diary_consolidation( raw_response=raw_response, success=success_flag, ) + diff --git a/ai_diplomacy/game_history.py b/ai_diplomacy/game_history.py index b37e240..6dee721 100644 --- a/ai_diplomacy/game_history.py +++ b/ai_diplomacy/game_history.py @@ -148,7 +148,6 @@ class GameHistory: return {} return self.phases[-1].plans - # NEW METHOD def get_messages_this_round(self, power_name: str, current_phase_name: str) -> str: current_phase: Optional[Phase] = None for phase_obj in self.phases: diff --git a/bot_client/config.py b/bot_client/config.py index 6c361d6..f5ed074 100644 --- a/bot_client/config.py +++ b/bot_client/config.py @@ -1,4 +1,4 @@ -from os.path import exists +import datetime from pydantic_settings import BaseSettings from pathlib import Path import warnings @@ -6,16 +6,16 @@ import warnings class Configuration(BaseSettings): DEBUG: bool = False - log_file_path: Path = Path("./logs/logs.txt") + log_file_path: Path DEEPSEEK_API_KEY: str | None = None OPENAI_API_KEY: str | None = None ANTHROPIC_API_KEY: str | None = None GEMINI_API_KEY: str | None = None OPENROUTER_API_KEY: str | None = None - def __init__(self, **kwargs): + def __init__(self, power_name, **kwargs): super().__init__(**kwargs) - + self.log_file_path = Path(f"./logs/{datetime.datetime.now().strftime('%d-%m-%y_%H:%M')}/{power_name}.txt") # Make the path absolute, gets rid of weirdness of calling this in different places self.log_file_path = self.log_file_path.resolve() self.log_file_path.parent.mkdir(parents=True, exist_ok=True) @@ -42,11 +42,7 @@ class Configuration(BaseSettings): """Override to check for empty API keys at access time""" value = super().__getattribute__(name) - if name.endswith("_KEY") and ( - not value or (isinstance(value, str) and len(value) == 0) - ): - raise ValueError( - f"API key '{name}' is not set or is empty. Please configure it before use." - ) + if name.endswith("_KEY") and (not value or (isinstance(value, str) and len(value) == 0)): + raise ValueError(f"API key '{name}' is not set or is empty. Please configure it before use.") return value diff --git a/bot_client/lm_game_websocket.py b/bot_client/lm_game_websocket.py index ec438ca..7ef86ce 100644 --- a/bot_client/lm_game_websocket.py +++ b/bot_client/lm_game_websocket.py @@ -30,12 +30,9 @@ from websocket_diplomacy_client import ( from models import ( SignInRequest, CreateGameRequest, - JoinGameRequest, SetOrdersRequest, ListGamesRequest, - ProcessGameRequest, serialize_message, - parse_message, ) import sys @@ -64,9 +61,7 @@ logging.getLogger("root").setLevel(logging.WARNING) def parse_arguments(): - parser = argparse.ArgumentParser( - description="Run a Diplomacy game simulation with WebSocket server connection." - ) + parser = argparse.ArgumentParser(description="Run a Diplomacy game simulation with WebSocket server connection.") parser.add_argument( "--hostname", type=str, @@ -137,9 +132,7 @@ def parse_arguments(): return parser.parse_args() -async def join_powers_for_testing( - client: WebSocketDiplomacyClient, power_model_map: dict -): +async def join_powers_for_testing(client: WebSocketDiplomacyClient, power_model_map: dict): """ Join multiple powers in the same game for testing purposes. This simulates having multiple AI players in one game. @@ -150,17 +143,13 @@ async def join_powers_for_testing( for power_name in power_names[1:]: try: logger.info(f"Attempting to join power {power_name}") - await client.channel.join_game( - game_id=client.game_id, power_name=power_name - ) + await client.channel.join_game(game_id=client.game_id, power_name=power_name) logger.info(f"Successfully joined {power_name}") except Exception as e: logger.warning(f"Could not join {power_name}: {e}") -async def create_or_join_game( - client: WebSocketDiplomacyClient, args, power_model_map: dict -): +async def create_or_join_game(client: WebSocketDiplomacyClient, args, power_model_map: dict): """ Create a new game or join an existing one based on arguments. """ @@ -171,9 +160,7 @@ async def create_or_join_game( # List available games first to see what's available try: games = await client.list_games() - logger.info( - f"Available games: {[g.get('game_id', 'unknown') for g in games]}" - ) + logger.info(f"Available games: {[g.get('game_id', 'unknown') for g in games]}") except Exception as e: logger.warning(f"Could not list games: {e}") @@ -189,19 +176,13 @@ async def create_or_join_game( logger.info("Creating new game") # Get the first power to control - first_power = ( - list(power_model_map.keys())[0] - if not args.create_multi_power_game - else None - ) + first_power = list(power_model_map.keys())[0] if not args.create_multi_power_game else None game = await client.create_game( map_name="standard", rules=["NO_PRESS", "IGNORE_ERRORS", "POWER_CHOICE"], power_name=first_power, - n_controls=7 - if not args.create_multi_power_game - else 1, # Lower requirement for testing + n_controls=7 if not args.create_multi_power_game else 1, # Lower requirement for testing deadline=None, # No time pressure for AI testing ) @@ -215,21 +196,17 @@ async def demonstrate_typed_messages(client: WebSocketDiplomacyClient, args): """ Demonstrate how to use typed WebSocket messages. This shows the proper way to construct messages according to the protocol. - + NOTE: This is for demonstration only - the current client implementation uses the diplomacy library's built-in networking which abstracts the raw protocol. """ - + logger.info("=== Typed Message Examples ===") - + # Example 1: Authentication Request - sign_in_request = SignInRequest( - request_id=str(uuid.uuid4()), - username=args.username, - password=args.password - ) + sign_in_request = SignInRequest(request_id=str(uuid.uuid4()), username=args.username, password=args.password) logger.info(f"SignIn request: {serialize_message(sign_in_request)}") - + # Example 2: Create Game Request create_game_request = CreateGameRequest( request_id=str(uuid.uuid4()), @@ -237,29 +214,25 @@ async def demonstrate_typed_messages(client: WebSocketDiplomacyClient, args): map_name="standard", rules=["NO_PRESS", "IGNORE_ERRORS", "POWER_CHOICE"], n_controls=1, - power_name="FRANCE" + power_name="FRANCE", ) logger.info(f"CreateGame request: {serialize_message(create_game_request)}") - + # Example 3: List Games Request - list_games_request = ListGamesRequest( - request_id=str(uuid.uuid4()), - token="example_auth_token_here", - include_protected=False - ) + list_games_request = ListGamesRequest(request_id=str(uuid.uuid4()), token="example_auth_token_here", include_protected=False) logger.info(f"ListGames request: {serialize_message(list_games_request)}") - + # Example 4: Set Orders Request (would be used in game loop) - if hasattr(client, 'game_id') and hasattr(client, 'token'): + if hasattr(client, "game_id") and hasattr(client, "token"): set_orders_request = SetOrdersRequest( request_id=str(uuid.uuid4()), token=client.token or "example_token", game_id=client.game_id or "example_game_id", game_role="FRANCE", - orders=["A PAR H", "F BRE H", "A MAR H"] + orders=["A PAR H", "F BRE H", "A MAR H"], ) logger.info(f"SetOrders request: {serialize_message(set_orders_request)}") - + logger.info("=== End Typed Message Examples ===") @@ -282,9 +255,7 @@ async def main(): logger.info("Starting WebSocket-based Diplomacy game with multiple LLMs") start_whole = time.time() - model_error_stats = defaultdict( - lambda: {"conversation_errors": 0, "order_decoding_errors": 0} - ) + model_error_stats = defaultdict(lambda: {"conversation_errors": 0, "order_decoding_errors": 0}) # Determine the result folder based on a timestamp timestamp_str = time.strftime("%Y%m%d_%H%M%S") @@ -306,9 +277,7 @@ async def main(): # File paths manifesto_path = f"{result_folder}/game_manifesto.txt" - game_file_path = ( - args.output if args.output else f"{result_folder}/lmvsgame_websocket.json" - ) + game_file_path = args.output if args.output else f"{result_folder}/lmvsgame_websocket.json" overview_file_path = f"{result_folder}/overview.jsonl" llm_log_file_path = f"{result_folder}/llm_responses.csv" @@ -325,9 +294,7 @@ async def main(): ] provided_models = [name.strip() for name in args.models.split(",")] if len(provided_models) != len(powers_order): - logger.error( - f"Expected {len(powers_order)} models for --models but got {len(provided_models)}. Exiting." - ) + logger.error(f"Expected {len(powers_order)} models for --models but got {len(provided_models)}. Exiting.") return power_model_map = dict(zip(powers_order, provided_models)) else: @@ -377,32 +344,20 @@ async def main(): client_obj = load_model_client(model_id) agent = DiplomacyAgent(power_name=power_name, client=client_obj) agents[power_name] = agent - logger.info( - f"Preparing initialization task for {power_name} with model {model_id}" - ) - initialization_tasks.append( - initialize_agent_state_ext( - agent, client.game, game_history, llm_log_file_path - ) - ) + logger.info(f"Preparing initialization task for {power_name} with model {model_id}") + initialization_tasks.append(initialize_agent_state_ext(agent, client.game, game_history, llm_log_file_path)) except Exception as e: logger.error( f"Failed to create agent for {power_name} with model {model_id}: {e}", exc_info=True, ) else: - logger.info( - f"Skipping agent initialization for {power_name} (no model or eliminated)" - ) + logger.info(f"Skipping agent initialization for {power_name} (no model or eliminated)") # Run initializations concurrently if initialization_tasks: - logger.info( - f"Running {len(initialization_tasks)} agent initializations concurrently..." - ) - initialization_results = await asyncio.gather( - *initialization_tasks, return_exceptions=True - ) + logger.info(f"Running {len(initialization_tasks)} agent initializations concurrently...") + initialization_results = await asyncio.gather(*initialization_tasks, return_exceptions=True) initialized_powers = list(agents.keys()) for i, result in enumerate(initialization_results): @@ -414,9 +369,7 @@ async def main(): exc_info=result, ) else: - logger.info( - f"Successfully initialized agent state for {power_name}." - ) + logger.info(f"Successfully initialized agent state for {power_name}.") # Main game loop all_phase_relationships = {} @@ -433,9 +386,7 @@ async def main(): game_history.add_phase(current_phase) current_short_phase = client.get_current_short_phase() - logger.info( - f"PHASE: {current_phase} (time so far: {phase_start - start_whole:.2f}s)" - ) + logger.info(f"PHASE: {current_phase} (time so far: {phase_start - start_whole:.2f}s)") # Prevent unbounded simulation based on year year_str = current_phase[1:5] @@ -447,9 +398,7 @@ async def main(): # Negotiations for movement phases if client.get_current_short_phase().endswith("M"): if args.num_negotiation_rounds > 0: - logger.info( - f"Running {args.num_negotiation_rounds} rounds of negotiations..." - ) + logger.info(f"Running {args.num_negotiation_rounds} rounds of negotiations...") game_history = await conduct_negotiations( client.game, # Pass the NetworkGame object agents, @@ -459,9 +408,7 @@ async def main(): log_file_path=llm_log_file_path, ) else: - logger.info( - "Skipping negotiation phase as num_negotiation_rounds=0" - ) + logger.info("Skipping negotiation phase as num_negotiation_rounds=0") # Planning phase (if enabled) if args.planning_phase: @@ -475,29 +422,19 @@ async def main(): ) # Generate negotiation diary entries - logger.info( - f"Generating negotiation diary entries for phase {current_short_phase}..." - ) - active_powers_for_neg_diary = [ - p for p in agents.keys() if not client.get_power(p).is_eliminated() - ] + logger.info(f"Generating negotiation diary entries for phase {current_short_phase}...") + active_powers_for_neg_diary = [p for p in agents.keys() if not client.get_power(p).is_eliminated()] neg_diary_tasks = [] for power_name, agent in agents.items(): if not client.get_power(power_name).is_eliminated(): - neg_diary_tasks.append( - agent.generate_negotiation_diary_entry( - client.game, game_history, llm_log_file_path - ) - ) + neg_diary_tasks.append(agent.generate_negotiation_diary_entry(client.game, game_history, llm_log_file_path)) if neg_diary_tasks: await asyncio.gather(*neg_diary_tasks, return_exceptions=True) # AI Decision Making: Get orders for each controlled power logger.info("Getting orders from agents...") - active_powers_for_orders = [ - p for p in agents.keys() if not client.get_power(p).is_eliminated() - ] + active_powers_for_orders = [p for p in agents.keys() if not client.get_power(p).is_eliminated()] order_tasks = [] order_power_names = [] @@ -505,35 +442,23 @@ async def main(): for power_name, agent in agents.items(): if client.get_power(power_name).is_eliminated(): - logger.debug( - f"Skipping order generation for eliminated power {power_name}." - ) + logger.debug(f"Skipping order generation for eliminated power {power_name}.") continue # Diagnostic logging - logger.info( - f"--- Diagnostic Log for {power_name} in phase {current_phase} ---" - ) + logger.info(f"--- Diagnostic Log for {power_name} in phase {current_phase} ---") try: orderable_locs = client.get_orderable_locations(power_name) - logger.info( - f"[{power_name}][{current_phase}] Orderable locations: {orderable_locs}" - ) + logger.info(f"[{power_name}][{current_phase}] Orderable locations: {orderable_locs}") actual_units = client.get_units(power_name) - logger.info( - f"[{power_name}][{current_phase}] Actual units: {actual_units}" - ) + logger.info(f"[{power_name}][{current_phase}] Actual units: {actual_units}") except Exception as e_diag: - logger.error( - f"[{power_name}][{current_phase}] Error during diagnostic logging: {e_diag}" - ) + logger.error(f"[{power_name}][{current_phase}] Error during diagnostic logging: {e_diag}") # Calculate possible orders possible_orders = gather_possible_orders(client.game, power_name) if not possible_orders: - logger.debug( - f"No orderable locations for {power_name}; submitting empty orders." - ) + logger.debug(f"No orderable locations for {power_name}; submitting empty orders.") await client.set_orders(power_name, []) continue @@ -559,12 +484,8 @@ async def main(): # Run order generation concurrently if order_tasks: - logger.debug( - f"Running {len(order_tasks)} order generation tasks concurrently..." - ) - order_results = await asyncio.gather( - *order_tasks, return_exceptions=True - ) + logger.debug(f"Running {len(order_tasks)} order generation tasks concurrently...") + order_results = await asyncio.gather(*order_tasks, return_exceptions=True) else: order_results = [] @@ -580,9 +501,7 @@ async def main(): ) await client.set_orders(p_name, []) elif result is None: - logger.warning( - f"get_valid_orders returned None for {p_name}. Setting empty orders." - ) + logger.warning(f"get_valid_orders returned None for {p_name}. Setting empty orders.") await client.set_orders(p_name, []) else: orders = result @@ -598,17 +517,13 @@ async def main(): # orders=orders # ) # await client.send_typed_message(set_orders_msg) - + await client.set_orders(p_name, orders) - logger.debug( - f"Set orders for {p_name} in {current_short_phase}: {orders}" - ) + logger.debug(f"Set orders for {p_name} in {current_short_phase}: {orders}") # Generate order diary entry try: - await agent.generate_order_diary_entry( - client.game, orders, llm_log_file_path - ) + await agent.generate_order_diary_entry(client.game, orders, llm_log_file_path) except Exception as e_diary: logger.error( f"Error generating order diary for {p_name}: {e_diary}", @@ -627,7 +542,7 @@ async def main(): # phase=current_phase # ) # await client.send_typed_message(process_msg) - + logger.info(f"Processing orders for {current_phase}...") await simulate_game_processing(client) @@ -650,22 +565,13 @@ async def main(): # Collect relationships for this phase current_relationships_for_phase = {} for power_name, agent in agents.items(): - if ( - power_name in client.powers - and not client.get_power(power_name).is_eliminated() - ): + if power_name in client.powers and not client.get_power(power_name).is_eliminated(): current_relationships_for_phase[power_name] = agent.relationships - all_phase_relationships[current_short_phase] = ( - current_relationships_for_phase - ) + all_phase_relationships[current_short_phase] = current_relationships_for_phase # Generate phase result diary entries - logger.info( - f"Generating phase result diary entries for completed phase {current_phase}..." - ) - phase_summary = getattr(client.game, "phase_summaries", {}).get( - current_phase, "(Summary not generated)" - ) + logger.info(f"Generating phase result diary entries for completed phase {current_phase}...") + phase_summary = getattr(client.game, "phase_summaries", {}).get(current_phase, "(Summary not generated)") all_orders_this_phase = current_order_history phase_result_diary_tasks = [] @@ -685,16 +591,10 @@ async def main(): await asyncio.gather(*phase_result_diary_tasks, return_exceptions=True) # State update analysis - logger.info( - f"Starting state update analysis for completed phase {current_phase}..." - ) + logger.info(f"Starting state update analysis for completed phase {current_phase}...") current_board_state = client.get_state() - active_agent_powers = [ - (p, power) - for p, power in client.powers.items() - if p in agents and not power.is_eliminated() - ] + active_agent_powers = [(p, power) for p, power in client.powers.items() if p in agents and not power.is_eliminated()] if active_agent_powers: state_update_tasks = [] @@ -781,4 +681,3 @@ async def main(): if __name__ == "__main__": asyncio.run(main()) - diff --git a/bot_client/models.py b/bot_client/models.py index bc238f6..839055b 100644 --- a/bot_client/models.py +++ b/bot_client/models.py @@ -8,34 +8,39 @@ and notifications (server -> client). from typing import Optional, List, Dict, Any, Union, Literal from pydantic import BaseModel, Field -from abc import ABC, abstractmethod +from abc import ABC # ============================================================================= # Base Message Classes # ============================================================================= + class BaseMessage(BaseModel, ABC): """Base class for all WebSocket messages.""" + name: str - + class Config: extra = "forbid" class BaseRequest(BaseMessage): """Base class for all client -> server requests.""" + request_id: str re_sent: bool = False class BaseResponse(BaseMessage): """Base class for all server -> client responses.""" + request_id: str class BaseNotification(BaseMessage): """Base class for all server -> client notifications (no request_id).""" + pass @@ -43,8 +48,10 @@ class BaseNotification(BaseMessage): # Authentication & Connection Level Messages # ============================================================================= + class SignInRequest(BaseRequest): """Client authentication request.""" + name: Literal["sign_in"] = "sign_in" username: str password: str @@ -52,6 +59,7 @@ class SignInRequest(BaseRequest): class GetDaidePortRequest(BaseRequest): """Request DAIDE TCP port for a game.""" + name: Literal["get_daide_port"] = "get_daide_port" game_id: str @@ -60,13 +68,16 @@ class GetDaidePortRequest(BaseRequest): # Channel Level Messages (require authentication token) # ============================================================================= + class ChannelRequest(BaseRequest): """Base class for channel-level requests that require authentication.""" + token: str class CreateGameRequest(ChannelRequest): """Create a new game.""" + name: Literal["create_game"] = "create_game" map_name: str = "standard" rules: List[str] = Field(default_factory=lambda: ["NO_PRESS", "IGNORE_ERRORS"]) @@ -78,6 +89,7 @@ class CreateGameRequest(ChannelRequest): class JoinGameRequest(ChannelRequest): """Join an existing game.""" + name: Literal["join_game"] = "join_game" game_id: str power_name: Optional[str] = None @@ -86,6 +98,7 @@ class JoinGameRequest(ChannelRequest): class JoinPowersRequest(ChannelRequest): """Join multiple powers in a game.""" + name: Literal["join_powers"] = "join_powers" game_id: str power_names: List[str] @@ -94,32 +107,37 @@ class JoinPowersRequest(ChannelRequest): class ListGamesRequest(ChannelRequest): """List available games.""" + name: Literal["list_games"] = "list_games" game_id_filter: Optional[str] = None map_name: Optional[str] = None - status: Optional[str] = None + status: Optional[str] = None include_protected: bool = False class GetPlayablePowersRequest(ChannelRequest): """Get uncontrolled powers in a game.""" + name: Literal["get_playable_powers"] = "get_playable_powers" game_id: str class GetAvailableMapsRequest(ChannelRequest): """Get list of available maps.""" + name: Literal["get_available_maps"] = "get_available_maps" class GetDummyWaitingPowersRequest(ChannelRequest): """Get AI-controllable powers (bot use).""" + name: Literal["get_dummy_waiting_powers"] = "get_dummy_waiting_powers" game_id: str class SetGradeRequest(ChannelRequest): """Modify user permissions.""" + name: Literal["set_grade"] = "set_grade" username: str grade: str @@ -127,11 +145,13 @@ class SetGradeRequest(ChannelRequest): class DeleteAccountRequest(ChannelRequest): """Delete user account.""" + name: Literal["delete_account"] = "delete_account" class LogoutRequest(ChannelRequest): """Disconnect from server.""" + name: Literal["logout"] = "logout" @@ -139,8 +159,10 @@ class LogoutRequest(ChannelRequest): # Game Level Messages (require authentication + game context) # ============================================================================= + class GameRequest(ChannelRequest): """Base class for game-level requests.""" + game_id: str game_role: str # Power name like "ENGLAND" phase: Optional[str] = None @@ -148,18 +170,21 @@ class GameRequest(ChannelRequest): class SetOrdersRequest(GameRequest): """Submit orders for a power.""" + name: Literal["set_orders"] = "set_orders" orders: List[str] class SetWaitFlagRequest(GameRequest): """Set wait flag for turn processing.""" + name: Literal["set_wait_flag"] = "set_wait_flag" wait: bool class SendGameMessageRequest(GameRequest): """Send diplomatic message.""" + name: Literal["send_game_message"] = "send_game_message" recipient: str # Power name or "GLOBAL" message: str @@ -168,11 +193,13 @@ class SendGameMessageRequest(GameRequest): class GetAllPossibleOrdersRequest(GameRequest): """Get legal orders for current phase.""" + name: Literal["get_all_possible_orders"] = "get_all_possible_orders" class GetPhaseHistoryRequest(GameRequest): """Get historical game phases.""" + name: Literal["get_phase_history"] = "get_phase_history" from_phase: Optional[str] = None to_phase: Optional[str] = None @@ -180,45 +207,53 @@ class GetPhaseHistoryRequest(GameRequest): class ProcessGameRequest(GameRequest): """Force game processing (master only).""" + name: Literal["process_game"] = "process_game" class VoteRequest(GameRequest): """Vote for/against draw.""" + name: Literal["vote"] = "vote" vote: Literal["yes", "no"] class SaveGameRequest(GameRequest): """Export game as JSON.""" + name: Literal["save_game"] = "save_game" class SetGameStateRequest(GameRequest): """Modify game state (master only).""" + name: Literal["set_game_state"] = "set_game_state" state: Dict[str, Any] class SetGameStatusRequest(GameRequest): """Change game status (master only).""" + name: Literal["set_game_status"] = "set_game_status" status: str class SetDummyPowersRequest(GameRequest): """Make powers AI-controlled (master only).""" + name: Literal["set_dummy_powers"] = "set_dummy_powers" power_names: List[str] class DeleteGameRequest(GameRequest): """Delete game (master only).""" + name: Literal["delete_game"] = "delete_game" class LeaveGameRequest(GameRequest): """Leave game.""" + name: Literal["leave_game"] = "leave_game" @@ -226,13 +261,16 @@ class LeaveGameRequest(GameRequest): # Response Messages (Server -> Client) # ============================================================================= + class OkResponse(BaseResponse): """Generic success response.""" + name: Literal["ok"] = "ok" class ErrorResponse(BaseResponse): """Error response with error type and message.""" + name: Literal["error"] = "error" error_type: str message: str @@ -240,60 +278,70 @@ class ErrorResponse(BaseResponse): class DataTokenResponse(BaseResponse): """Contains authentication token.""" + name: Literal["data_token"] = "data_token" data: str # The authentication token class DataGameResponse(BaseResponse): """Contains full game object.""" + name: Literal["data_game"] = "data_game" data: Dict[str, Any] # The complete game state class DataGameInfoResponse(BaseResponse): """Contains game metadata.""" + name: Literal["data_game_info"] = "data_game_info" data: Dict[str, Any] class DataGamesResponse(BaseResponse): """List of game information.""" + name: Literal["data_games"] = "data_games" data: List[Dict[str, Any]] class DataMapsResponse(BaseResponse): """Available maps information.""" + name: Literal["data_maps"] = "data_maps" data: List[str] class DataPowerNamesResponse(BaseResponse): """List of power names.""" + name: Literal["data_power_names"] = "data_power_names" data: List[str] class DataPossibleOrdersResponse(BaseResponse): """Legal orders and locations.""" + name: Literal["data_possible_orders"] = "data_possible_orders" data: Dict[str, List[str]] # Location -> list of possible orders class DataGamePhasesResponse(BaseResponse): """Historical game phases.""" + name: Literal["data_game_phases"] = "data_game_phases" data: List[Dict[str, Any]] class DataSavedGameResponse(BaseResponse): """Exported game JSON.""" + name: Literal["data_saved_game"] = "data_saved_game" data: Dict[str, Any] class DataPortResponse(BaseResponse): """DAIDE port number.""" + name: Literal["data_port"] = "data_port" data: int @@ -302,8 +350,10 @@ class DataPortResponse(BaseResponse): # Notification Messages (Server -> Client) # ============================================================================= + class GameProcessedNotification(BaseNotification): """Phase completed, new orders phase.""" + name: Literal["game_processed"] = "game_processed" game_id: str phase: str @@ -312,6 +362,7 @@ class GameProcessedNotification(BaseNotification): class GamePhaseUpdateNotification(BaseNotification): """Game state changed.""" + name: Literal["game_phase_update"] = "game_phase_update" game_id: str phase: str @@ -320,6 +371,7 @@ class GamePhaseUpdateNotification(BaseNotification): class GameStatusUpdateNotification(BaseNotification): """Game status changed (forming/active/paused/completed).""" + name: Literal["game_status_update"] = "game_status_update" game_id: str status: str @@ -327,6 +379,7 @@ class GameStatusUpdateNotification(BaseNotification): class PowersControllersNotification(BaseNotification): """Power control assignments changed.""" + name: Literal["powers_controllers"] = "powers_controllers" game_id: str controllers: Dict[str, str] # Power -> Controller mapping @@ -334,6 +387,7 @@ class PowersControllersNotification(BaseNotification): class PowerOrdersUpdateNotification(BaseNotification): """Player submitted new orders.""" + name: Literal["power_orders_update"] = "power_orders_update" game_id: str power_name: str @@ -343,15 +397,17 @@ class PowerOrdersUpdateNotification(BaseNotification): class PowerOrdersFlagNotification(BaseNotification): """Player order submission status.""" + name: Literal["power_orders_flag"] = "power_orders_flag" game_id: str - power_name: str + power_name: str order_is_set: bool phase: str class PowerWaitFlagNotification(BaseNotification): """Player wait flag changed.""" + name: Literal["power_wait_flag"] = "power_wait_flag" game_id: str power_name: str @@ -360,6 +416,7 @@ class PowerWaitFlagNotification(BaseNotification): class GameMessageReceivedNotification(BaseNotification): """Diplomatic message received.""" + name: Literal["game_message_received"] = "game_message_received" game_id: str sender: str @@ -371,6 +428,7 @@ class GameMessageReceivedNotification(BaseNotification): class VoteUpdatedNotification(BaseNotification): """Draw votes changed (omniscient view).""" + name: Literal["vote_updated"] = "vote_updated" game_id: str votes: Dict[str, str] # Power -> vote mapping @@ -378,6 +436,7 @@ class VoteUpdatedNotification(BaseNotification): class VoteCountUpdatedNotification(BaseNotification): """Vote count changed (observer view).""" + name: Literal["vote_count_updated"] = "vote_count_updated" game_id: str count_yes: int @@ -386,6 +445,7 @@ class VoteCountUpdatedNotification(BaseNotification): class PowerVoteUpdatedNotification(BaseNotification): """Own power's vote changed.""" + name: Literal["power_vote_updated"] = "power_vote_updated" game_id: str power_name: str @@ -394,12 +454,14 @@ class PowerVoteUpdatedNotification(BaseNotification): class GameDeletedNotification(BaseNotification): """Game removed from server.""" + name: Literal["game_deleted"] = "game_deleted" game_id: str class OmniscientUpdatedNotification(BaseNotification): """Observer permissions changed.""" + name: Literal["omniscient_updated"] = "omniscient_updated" game_id: str omniscient_type: str @@ -407,12 +469,14 @@ class OmniscientUpdatedNotification(BaseNotification): class AccountDeletedNotification(BaseNotification): """User account deleted.""" + name: Literal["account_deleted"] = "account_deleted" username: str class ClearedCentersNotification(BaseNotification): """Supply centers cleared.""" + name: Literal["cleared_centers"] = "cleared_centers" game_id: str power_name: str @@ -420,6 +484,7 @@ class ClearedCentersNotification(BaseNotification): class ClearedOrdersNotification(BaseNotification): """Orders cleared.""" + name: Literal["cleared_orders"] = "cleared_orders" game_id: str power_name: str @@ -428,6 +493,7 @@ class ClearedOrdersNotification(BaseNotification): class ClearedUnitsNotification(BaseNotification): """Units cleared.""" + name: Literal["cleared_units"] = "cleared_units" game_id: str power_name: str @@ -514,24 +580,25 @@ WebSocketMessage = Union[RequestMessage, ResponseMessage, NotificationMessage] # Utility Functions # ============================================================================= + def parse_message(data: Dict[str, Any]) -> WebSocketMessage: """ Parse a raw WebSocket message dictionary into the appropriate pydantic model. - + Args: data: Raw message dictionary from WebSocket - + Returns: Parsed message object - + Raises: ValueError: If message cannot be parsed or is of unknown type """ if not isinstance(data, dict) or "name" not in data: raise ValueError("Invalid message format: missing 'name' field") - + message_name = data["name"] - + # Map message names to their corresponding classes message_classes = { # Requests @@ -592,22 +659,22 @@ def parse_message(data: Dict[str, Any]) -> WebSocketMessage: "cleared_orders": ClearedOrdersNotification, "cleared_units": ClearedUnitsNotification, } - + message_class = message_classes.get(message_name) if message_class is None: raise ValueError(f"Unknown message type: {message_name}") - + return message_class(**data) def serialize_message(message: WebSocketMessage) -> Dict[str, Any]: """ Serialize a pydantic message object to a dictionary for WebSocket transmission. - + Args: message: Pydantic message object - + Returns: Dictionary representation of the message """ - return message.model_dump(exclude_none=True) \ No newline at end of file + return message.model_dump(exclude_none=True) diff --git a/bot_client/multi_bot_launcher.py b/bot_client/multi_bot_launcher.py index 2c938b6..a8fecfa 100644 --- a/bot_client/multi_bot_launcher.py +++ b/bot_client/multi_bot_launcher.py @@ -50,13 +50,13 @@ class MultiBotLauncher: # Default power to model mapping self.default_models = { - "AUSTRIA": "gpt-3.5-turbo", - "ENGLAND": "gpt-4", - "FRANCE": "claude-3-haiku", - "GERMANY": "gpt-3.5-turbo", - "ITALY": "gemini-pro", - "RUSSIA": "gpt-4", - "TURKEY": "claude-3-sonnet", + "AUSTRIA": "gemini-2.5-flash-lite-preview-06-17", + "ENGLAND": "gemini-2.5-flash-lite-preview-06-17", + "FRANCE": "gemini-2.5-flash-lite-preview-06-17", + "GERMANY": "gemini-2.5-flash-lite-preview-06-17", + "ITALY": "gemini-2.5-flash-lite-preview-06-17", + "RUSSIA": "gemini-2.5-flash-lite-preview-06-17", + "TURKEY": "gemini-2.5-flash-lite-preview-06-17", } async def create_game(self, creator_power: str = "FRANCE") -> str: @@ -99,7 +99,12 @@ class MultiBotLauncher: return game_id def launch_bot( - self, power: str, model: str, game_id: str, log_level: str = "INFO" + self, + power: str, + model: str, + game_id: str, + log_level: str = "INFO", + negotiation_rounds: int = 3, ) -> subprocess.Popen: """ Launch a single bot process. @@ -134,6 +139,8 @@ class MultiBotLauncher: game_id, "--log-level", log_level, + "--negotiation-rounds", + str(negotiation_rounds), ] logger.info(f"Launching bot for {power} with model {model}") @@ -156,7 +163,8 @@ class MultiBotLauncher: models: Optional[Dict[str, str]] = None, powers: Optional[List[str]] = None, log_level: str = "INFO", - stagger_delay: float = 2.0, + stagger_delay: float = 0.5, + negotiation_rounds: int = 3, ): """ Launch bots for all specified powers. @@ -180,13 +188,11 @@ class MultiBotLauncher: model = models.get(power, "gpt-3.5-turbo") try: - process = self.launch_bot(power, model, game_id, log_level) + process = self.launch_bot(power, model, game_id, log_level, negotiation_rounds) self.bot_processes.append(process) self.process_to_power[process] = power - logger.info( - f"Launched bot {i + 1}/{len(powers)}: {power} (PID: {process.pid})" - ) + logger.info(f"Launched bot {i + 1}/{len(powers)}: {power} (PID: {process.pid})") # Stagger the launches to avoid overwhelming the server if i < len(powers) - 1: # Don't delay after the last bot @@ -209,11 +215,11 @@ class MultiBotLauncher: try: while self.bot_processes: active_processes = [] - + # Collect all stdout file descriptors from active processes stdout_fds = [] fd_to_process = {} - + for process in self.bot_processes: if process.poll() is None: # Still running active_processes.append(process) @@ -224,15 +230,13 @@ class MultiBotLauncher: # Process has ended return_code = process.returncode power = self.process_to_power.get(process, "UNKNOWN") - logger.info( - f"{power} bot process {process.pid} ended with code {return_code}" - ) + logger.info(f"{power} bot process {process.pid} ended with code {return_code}") # Read any remaining output remaining_output = process.stdout.read() if remaining_output: print(f"{power}_{process.pid} final output: {remaining_output}") - + # Clean up the power mapping self.process_to_power.pop(process, None) @@ -241,20 +245,20 @@ class MultiBotLauncher: if not self.bot_processes: logger.info("All bots have finished") break - + # Use select to check which processes have output ready (Unix only) - if stdout_fds and hasattr(select, 'select'): + if stdout_fds and hasattr(select, "select"): try: ready_fds, _, _ = select.select(stdout_fds, [], [], 0.1) # 100ms timeout - + for fd in ready_fds: process = fd_to_process[fd] power = self.process_to_power.get(process, "UNKNOWN") - + # Read available lines (but limit to prevent monopolizing) lines_read = 0 max_lines_per_process = 10 - + while lines_read < max_lines_per_process: try: line = process.stdout.readline() @@ -264,7 +268,7 @@ class MultiBotLauncher: lines_read += 1 except: break - + except (OSError, ValueError): # Fallback if select fails self._fallback_read_output(active_processes) @@ -283,11 +287,11 @@ class MultiBotLauncher: """Fallback method for reading output when select is not available.""" for process in active_processes: power = self.process_to_power.get(process, "UNKNOWN") - + # Read limited lines per process to prevent monopolizing lines_read = 0 max_lines_per_process = 3 # More conservative for fallback - + while lines_read < max_lines_per_process: try: line = process.stdout.readline() @@ -324,6 +328,7 @@ class MultiBotLauncher: models: Optional[Dict[str, str]] = None, log_level: str = "INFO", creator_power: str = "FRANCE", + negotiation_rounds: int = 3, ): """ Create a game and launch all bots for a complete game. @@ -342,7 +347,12 @@ class MultiBotLauncher: await asyncio.sleep(2) # Launch all bots - await self.launch_all_bots(game_id, models, log_level=log_level) + await self.launch_all_bots( + game_id, + models, + log_level=log_level, + negotiation_rounds=negotiation_rounds, + ) # Monitor the bots self.monitor_bots() @@ -358,6 +368,7 @@ class MultiBotLauncher: powers: List[str], models: Optional[Dict[str, str]] = None, log_level: str = "INFO", + negotiation_rounds: int = 3, ): """ Launch bots to join an existing game. @@ -372,7 +383,13 @@ class MultiBotLauncher: self.game_id = game_id # Launch bots for specified powers - await self.launch_all_bots(game_id, models, powers, log_level) + await self.launch_all_bots( + game_id, + models, + powers, + log_level, + negotiation_rounds=negotiation_rounds, + ) # Monitor the bots self.monitor_bots() @@ -391,18 +408,16 @@ def parse_arguments(): parser.add_argument("--port", type=int, default=8432, help="Server port") parser.add_argument("--username-base", default="bot", help="Base username for bots") parser.add_argument("--password", default="password", help="Password for all bots") - parser.add_argument( - "--game-id", help="Game ID to join (creates new if not specified)" - ) - parser.add_argument( - "--powers", nargs="+", help="Powers to launch bots for (default: all)" - ) - parser.add_argument( - "--models", help="Comma-separated list of models in power order" - ) + parser.add_argument("--game-id", help="Game ID to join (creates new if not specified)") + parser.add_argument("--powers", nargs="+", help="Powers to launch bots for (default: all)") + parser.add_argument("--models", help="Comma-separated list of models in power order") parser.add_argument("--log-level", default="INFO", help="Logging level") + parser.add_argument("--creator-power", default="FRANCE", help="Power that creates the game") parser.add_argument( - "--creator-power", default="FRANCE", help="Power that creates the game" + "--negotiation-rounds", + type=int, + default=3, + help="Number of negotiation rounds per movement phase (default: 3)", ) return parser.parse_args() @@ -425,9 +440,7 @@ async def main(): model_list = [m.strip() for m in args.models.split(",")] powers = args.powers or list(launcher.default_models.keys()) if len(model_list) != len(powers): - logger.error( - f"Number of models ({len(model_list)}) must match number of powers ({len(powers)})" - ) + logger.error(f"Number of models ({len(model_list)}) must match number of powers ({len(powers)})") return models = dict(zip(powers, model_list)) @@ -440,6 +453,7 @@ async def main(): powers=powers, models=models, log_level=args.log_level, + negotiation_rounds=args.negotiation_rounds, ) else: # Create new game and launch all bots @@ -447,6 +461,7 @@ async def main(): models=models, log_level=args.log_level, creator_power=args.creator_power, + negotiation_rounds=args.negotiation_rounds, ) except KeyboardInterrupt: diff --git a/bot_client/pyproject.toml b/bot_client/pyproject.toml index 33446ba..412cf92 100644 --- a/bot_client/pyproject.toml +++ b/bot_client/pyproject.toml @@ -29,6 +29,7 @@ include = [".", ".."] [tool.ruff] src = ["..", "."] +line-length = 150 [tool.pytest.ini_options] testpaths = ["bot_client/tests"] @@ -44,8 +45,4 @@ markers = [ ] [dependency-groups] -dev = [ - "pytest>=8.4.1", - "pytest-asyncio>=0.21.0", - "websockets>=15.0.1", -] +dev = ["pytest>=8.4.1", "pytest-asyncio>=0.21.0", "websockets>=15.0.1"] diff --git a/bot_client/single_bot_player.py b/bot_client/single_bot_player.py index 146858c..d4fa7fb 100644 --- a/bot_client/single_bot_player.py +++ b/bot_client/single_bot_player.py @@ -13,25 +13,17 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..")) import argparse import asyncio import signal -from typing import Optional, Dict +from typing import Optional, Dict, List import dotenv from loguru import logger -from websocket_diplomacy_client import ( - WebSocketDiplomacyClient, - connect_to_diplomacy_server, -) - -from diplomacy.utils.exceptions import DiplomacyException, GameIdException - -# Suppress warnings -# os.environ["GRPC_PYTHON_LOG_LEVEL"] = "40" -# os.environ["GRPC_VERBOSITY"] = "ERROR" -# os.environ["ABSL_MIN_LOG_LEVEL"] = "2" -# os.environ["GRPC_POLL_STRATEGY"] = "poll" +from websocket_diplomacy_client import WebSocketDiplomacyClient, connect_to_diplomacy_server +from diplomacy.engine.power import Power +from diplomacy.utils.exceptions import GameIdException +from diplomacy.communication.notifications import GameStatusUpdate from diplomacy.engine.message import Message from ai_diplomacy.clients import load_model_client @@ -40,15 +32,16 @@ from ai_diplomacy.game_history import GameHistory from ai_diplomacy.agent import DiplomacyAgent from ai_diplomacy.initialization import initialize_agent_state_ext from config import Configuration +from websocket_negotiations import ( + conduct_strategic_negotiation_round, + should_participate_in_negotiations, + get_negotiation_delay, +) dotenv.load_dotenv() -config = Configuration() - -if config.DEBUG: - import tracemalloc - - tracemalloc.start() +# TODO: This, but better +config: Configuration = None class SingleBotPlayer: @@ -70,6 +63,7 @@ class SingleBotPlayer: hostname: str = "localhost", port: int = 8432, game_id: Optional[str] = None, + negotiation_rounds: int = 3, ): assert username is not None assert password is not None @@ -87,16 +81,24 @@ class SingleBotPlayer: # Bot state self.client: WebSocketDiplomacyClient self.agent: DiplomacyAgent - self.game_history = GameHistory() + self.game_history: GameHistory = GameHistory() self.running = True self.current_phase = None self.waiting_for_orders = False self.orders_submitted = False + # Negotiation settings + self.negotiation_rounds = negotiation_rounds + self.current_negotiation_round = 0 + self.negotiation_complete = False + # Track error stats - self.error_stats: Dict[str, Dict[str, int]] = { - self.model_name: {"conversation_errors": 0, "order_decoding_errors": 0} - } + self.error_stats: Dict[str, Dict[str, int]] = {self.model_name: {"conversation_errors": 0, "order_decoding_errors": 0}} + + # Track messaging patterns for strategic communication + self.message_counts: Dict[str, int] = {} # Messages received from each power + self.response_counts: Dict[str, int] = {} # Responses sent to each power + self.priority_contacts: List[str] = [] # Powers to prioritize for communication # Setup signal handlers for graceful shutdown signal.signal(signal.SIGINT, self._signal_handler) @@ -122,9 +124,7 @@ class SingleBotPlayer: # Join or create game if self.game_id: logger.info(f"Joining existing game {self.game_id} as {self.power_name}") - game = await self.client.join_game( - game_id=self.game_id, power_name=self.power_name - ) + await self.client.join_game(game_id=self.game_id, power_name=self.power_name) else: logger.info(f"Creating new game as {self.power_name}") await self.client.create_game( @@ -142,9 +142,7 @@ class SingleBotPlayer: self.agent = DiplomacyAgent(power_name=self.power_name, client=model_client) # Initialize agent state - await initialize_agent_state_ext( - self.agent, self.client.game, self.game_history, config.log_file_path - ) + await initialize_agent_state_ext(self.agent, self.client.game, self.game_history, config.log_file_path) # Setup game event callbacks await self._setup_event_callbacks() @@ -198,6 +196,16 @@ class SingleBotPlayer: self.current_phase = new_phase self.game_history.add_phase(new_phase) self.orders_submitted = False + self.current_negotiation_round = 0 + self.negotiation_complete = False + + # Log message statistics at phase transitions + if hasattr(self, "message_counts") and self.message_counts: + self.log_message_statistics() + + # Check if we should start negotiations for movement phases (not adjustment phases) + if new_phase.endswith("M"): + await self._handle_negotiation_phase() # Check if we need to submit orders for this new phase await self._check_if_orders_needed() @@ -223,9 +231,7 @@ class SingleBotPlayer: def _on_message_received(self, game, notification): """Handle incoming diplomatic messages.""" message = notification.message - logger.info( - f"Message received from {message.sender} to {message.recipient}: {message.message}" - ) + logger.info(f"Message received from {message.sender} to {message.recipient}: {message.message}") # Add message to game history self.game_history.add_message( @@ -235,12 +241,17 @@ class SingleBotPlayer: message_content=message.message, ) + # Track message patterns + if message.recipient == self.power_name and message.sender != self.power_name: + self.message_counts[message.sender] = self.message_counts.get(message.sender, 0) + 1 + self._update_priority_contacts() + # If it's a private message to us, consider responding if message.recipient == self.power_name and message.sender != self.power_name: # Schedule the async processing in the event loop asyncio.create_task(self._consider_message_response(message)) - def _on_status_update(self, game, notification): + def _on_status_update(self, game, notification: GameStatusUpdate): """Handle game status changes.""" logger.info(f"Game status updated: {notification.status}") @@ -261,20 +272,50 @@ class SingleBotPlayer: # Check if it's a phase where we can submit orders current_short_phase = self.client.game.current_short_phase - # We submit orders in Movement and Retreat phases - if current_short_phase.endswith("M") or current_short_phase.endswith("R"): - # Check if we have units that can receive orders - orderable_locations = self.client.game.get_orderable_locations( - self.power_name - ) - if orderable_locations: - logger.info(f"Orders needed for phase {current_short_phase}") - self.waiting_for_orders = True - await self._submit_orders() - else: - logger.info( - f"No orderable locations for {self.power_name} in {current_short_phase}" - ) + logger.debug(f"Checking if orders needed for phase: {current_short_phase}") + + # Movement and Retreat phases + orderable_locations = self.client.game.get_orderable_locations(self.power_name) + if orderable_locations: + logger.info(f"Orders needed for {current_short_phase} phase - orderable locations: {orderable_locations}") + self.waiting_for_orders = True + await self._submit_orders() + else: + logger.info(f"No orderable locations for {self.power_name} in {current_short_phase}") + + async def _submit_adjustment_orders(self, action_type, count): + """Submit build or disband orders for adjustment phase.""" + # FIXME: This whole function is horse shit. + return + if action_type == "build": + # Get buildable locations for this power + possible_orders = gather_possible_order(self.client.game, self.power_name) + if not possible_orders: + logger.warning(f"No possible build orders for {self.power_name}") + await self.client.set_orders(self.power_name, []) + self.orders_submitted = True + return + + # Filter for build orders (usually start with unit type + location) + build_orders = [order for order in possible_orders if " - " not in order and any(order.startswith(unit) for unit in ["A ", "F "])] + + logger.info(f"Available build orders for {self.power_name}: {build_orders}") + + # Select up to 'count' build orders + selected_orders = build_orders[:count] + logger.info(f"Submitting build orders for {self.power_name}: {selected_orders}") + + await self.client.set_orders(self.power_name, selected_orders) + + elif action_type == "disband": + # Get current units for disbanding + + logger.info(f"Submitting disband orders for {self.power_name}: {disband_orders}") + await self.client.set_orders(self.power_name, disband_orders) + + self.orders_submitted = True + self.waiting_for_orders = False + logger.info(f"Adjustment orders submitted successfully for {self.power_name}") async def _submit_orders(self): """Generate and submit orders for the current phase.""" @@ -282,108 +323,280 @@ class SingleBotPlayer: logger.debug("Orders already submitted for this phase") return - try: - logger.info("Generating orders...") + current_phase = self.client.game.get_current_phase() + logger.info(f"Generating orders for {self.power_name} in phase {current_phase}...") - # Get current board state - board_state = self.client.game.get_state() + # Get current board state + board_state = self.client.game.get_state() - # Get possible orders - possible_orders = gather_possible_orders(self.client.game, self.power_name) + # Get possible orders + possible_orders = gather_possible_orders(self.client.game, self.power_name) - if not possible_orders: - logger.info("No possible orders, submitting empty order set") - await self.client.set_orders(self.power_name, []) - self.orders_submitted = True - return - - # Generate orders using AI - orders = await get_valid_orders( - game=self.client.game, - client=self.agent.client, - board_state=board_state, - power_name=self.power_name, - possible_orders=possible_orders, - game_history=self.game_history, - model_error_stats=self.error_stats, - agent_goals=self.agent.goals, - agent_relationships=self.agent.relationships, - agent_private_diary_str=self.agent.format_private_diary_for_prompt(), - phase=self.client.game.get_current_phase(), - ) - - # Submit orders - if orders: - logger.info(f"Submitting orders: {orders}") - await self.client.set_orders(self.power_name, orders) - - # Generate order diary entry - await self.agent.generate_order_diary_entry( - self.client.game, - orders, - config.log_file_path, - ) - else: - logger.info("No valid orders generated, submitting empty order set") - await self.client.set_orders(self.power_name, []) + logger.debug(f"Possible orders for {self.power_name}: {possible_orders}") + if not possible_orders: + logger.info(f"No possible orders for {self.power_name}, submitting empty order set") + await self.client.set_orders(self.power_name, []) self.orders_submitted = True - self.waiting_for_orders = False - logger.info("Orders submitted successfully") - # Call the no wait so we don't sit around for the turns to end. - self.client.game.no_wait() + return - except DiplomacyException as e: - logger.error(f"Error submitting orders: {e}", exc_info=True) - # FIXME: I don't think we want to do this. Likely want to retry again multiple times. - # - # Submit empty orders as fallback - try: - await self.client.set_orders(self.power_name, []) - self.orders_submitted = True - except Exception as fallback_error: - logger.error(f"Failed to submit fallback orders: {fallback_error}") + # Generate orders using AI + orders = await get_valid_orders( + game=self.client.game, + client=self.agent.client, + board_state=board_state, + power_name=self.power_name, + possible_orders=possible_orders, + game_history=self.game_history, + model_error_stats=self.error_stats, + agent_goals=self.agent.goals, + agent_relationships=self.agent.relationships, + agent_private_diary_str=self.agent.format_private_diary_for_prompt(), + phase=self.client.game.get_current_phase(), + ) + + # Submit orders + if orders: + logger.info(f"Submitting orders: {orders}") + await self.client.set_orders(self.power_name, orders) + + # Generate order diary entry + await self.agent.generate_order_diary_entry( + self.client.game, + orders, + config.log_file_path, + ) + else: + logger.info("No valid orders generated, submitting empty order set") + await self.client.set_orders(self.power_name, []) + + self.orders_submitted = True + self.waiting_for_orders = False + logger.info("Orders submitted successfully") + # Call the no wait so we don't sit around for the turns to end. + # TODO: We probably don't want to call this here. + # We want to call it when negotiations end, + self.client.game.no_wait() async def _analyze_phase_results(self): """Analyze the results of the previous phase.""" - try: - logger.info("Analyzing phase results...") + logger.info("Analyzing phase results...") - # Get current board state after processing - board_state = self.client.game.get_state() + # Get current board state after processing + board_state = self.client.game.get_state() - # Generate a simple phase summary - phase_summary = f"Phase {self.current_phase} completed." + # Generate a simple phase summary + phase_summary = f"Phase {self.current_phase} completed." - # Update agent state based on results - await self.agent.analyze_phase_and_update_state( - game=self.client.game, - board_state=board_state, - phase_summary=phase_summary, + # Update agent state based on results + await self.agent.analyze_phase_and_update_state( + game=self.client.game, + board_state=board_state, + phase_summary=phase_summary, + game_history=self.game_history, + log_file_path=config.log_file_path, + ) + + logger.info("Phase analysis complete") + + async def _handle_negotiation_phase(self): + """Handle the negotiation phase for movement turns.""" + # Check if we should participate in negotiations + if not await should_participate_in_negotiations(self.client, self.agent): + logger.info(f"{self.power_name} will not participate in negotiations this phase") + self.negotiation_complete = True + return + + logger.info(f"Starting negotiation phase for {self.power_name}") + # TODO: This doesn't need a specific number of negotiation rounds, though it should have a top number of messages this turn so they don't blabber on forever. + + # Conduct negotiations for the specified number of rounds + for round_num in range(1, self.negotiation_rounds + 1): + self.current_negotiation_round = round_num + + logger.info(f"Negotiation round {round_num}/{self.negotiation_rounds} for {self.power_name}") + + # Use strategic negotiation that analyzes recent messages + success = await conduct_strategic_negotiation_round( + client=self.client, + agent=self.agent, game_history=self.game_history, + model_error_stats=self.error_stats, log_file_path=config.log_file_path, + round_number=round_num, + max_rounds=self.negotiation_rounds, ) - logger.info("Phase analysis complete") + if not success: + logger.info(f"No messages sent in round {round_num} for {self.power_name}") - except Exception as e: - logger.error(f"Error analyzing phase results: {e}", exc_info=True) + # Wait between rounds to allow other bots to respond + if round_num < self.negotiation_rounds: + delay = get_negotiation_delay(round_num, self.negotiation_rounds) + logger.debug(f"Waiting {delay}s before next negotiation round") + await asyncio.sleep(delay) + + self.negotiation_complete = True + logger.info(f"Negotiation phase complete for {self.power_name}") async def _consider_message_response(self, message: Message): """Consider whether to respond to a diplomatic message.""" try: - # Simple logic: if someone greets us, greet back - if any( - word in message.message.lower() for word in ["hello", "hi", "greetings"] - ): - response = f"Hello {message.sender}! Good to hear from you." - await self.client.game.send_game_message( - sender=self.power_name, recipient=message.sender, message=response + # Only respond to messages directed at us specifically + if message.recipient != self.power_name: + return + + # Don't respond to our own messages + if message.sender == self.power_name: + return + + logger.info(f"Considering response to message from {message.sender}: {message.message[:50]}...") + + # Enhanced heuristic: respond to direct questions, proposals, and strategic keywords + message_lower = message.message.lower() + strategic_keywords = [ + "alliance", + "deal", + "propose", + "agreement", + "support", + "attack", + "coordinate", + "move", + "order", + "help", + "work together", + "partner", + "enemy", + "threat", + "negotiate", + "discuss", + "plan", + "strategy", + "bounce", + "convoy", + "retreat", + ] + + should_respond = any( + [ + "?" in message.message, # Questions + any(word in message_lower for word in ["hello", "hi", "greetings"]), # Greetings + any(keyword in message_lower for keyword in strategic_keywords), # Strategic content + len(message.message.split()) > 15, # Longer messages suggest they want engagement + message.sender in self.priority_contacts, # Priority contacts + ] + ) + + if should_respond: + # Generate a contextual response using AI + # Get current game state for context + board_state = self.client.get_state() + possible_orders = gather_possible_orders(self.client.game, self.power_name) + + # Create a simple conversation context + active_powers = [p_name for p_name, p_obj in self.client.powers.items() if not p_obj.is_eliminated()] + + # Generate response using the agent's conversation capabilities + responses = await self.agent.client.get_conversation_reply( + game=self.client.game, + board_state=board_state, + power_name=self.power_name, + possible_orders=possible_orders, + game_history=self.game_history, + game_phase=self.client.get_current_short_phase(), + log_file_path=config.log_file_path, + active_powers=active_powers, + agent_goals=self.agent.goals, + agent_relationships=self.agent.relationships, + agent_private_diary_str=self.agent.format_private_diary_for_prompt(), ) - logger.info(f"Sent response to {message.sender}: {response}") + + # Send the first response if any were generated + if responses and len(responses) > 0: + response_content = responses[0].get("content", "").strip() + if response_content: + await self.client.send_message( + sender=self.power_name, + recipient=message.sender, + message=response_content, + phase=self.client.get_current_short_phase(), + ) + + # Add to game history + self.game_history.add_message( + phase_name=self.client.get_current_short_phase(), + sender=self.power_name, + recipient=message.sender, + message_content=response_content, + ) + + # Track response patterns + self.response_counts[message.sender] = self.response_counts.get(message.sender, 0) + 1 + + # Add to agent's journal + self.agent.add_journal_entry( + f"Responded to {message.sender} in {self.client.get_current_short_phase()}: {response_content[:100]}..." + ) + + logger.info(f"Sent AI response to {message.sender}: {response_content[:50]}...") + else: + logger.debug(f"AI generated empty response to {message.sender}") + else: + logger.debug(f"AI generated no responses to {message.sender}") + else: + logger.debug(f"Decided not to respond to message from {message.sender}") except Exception as e: - logger.error(f"Error responding to message: {e}") + logger.error(f"Error responding to message: {e}", exc_info=True) + + def _update_priority_contacts(self) -> None: + """Update the list of priority contacts based on messaging patterns.""" + # Sort powers by message count (descending) and take top 3-4 + sorted_contacts = sorted(self.message_counts.items(), key=lambda x: x[1], reverse=True) + + # Keep top 4 most active contacts as priority + self.priority_contacts = [contact[0] for contact in sorted_contacts[:4]] + + logger.debug(f"Updated priority contacts for {self.power_name}: {self.priority_contacts}") + + def get_message_statistics(self) -> Dict[str, any]: + """Get comprehensive statistics about messaging patterns.""" + active_powers = [p_name for p_name, p_obj in self.client.powers.items() if not p_obj.is_eliminated() and p_name != self.power_name] + + stats = { + "power_name": self.power_name, + "total_messages_received": sum(self.message_counts.values()), + "total_responses_sent": sum(self.response_counts.values()), + "message_counts_by_power": dict(self.message_counts), + "response_counts_by_power": dict(self.response_counts), + "priority_contacts": list(self.priority_contacts), + "response_rate_by_power": {}, + "active_powers": active_powers, + "current_phase": self.current_phase, + } + + # Calculate response rates + for power, received in self.message_counts.items(): + sent = self.response_counts.get(power, 0) + stats["response_rate_by_power"][power] = sent / received if received > 0 else 0.0 + + return stats + + def log_message_statistics(self) -> None: + """Log current message statistics for analysis.""" + stats = self.get_message_statistics() + + logger.info(f"Message Statistics for {self.power_name}:") + logger.info(f" Total messages received: {stats['total_messages_received']}") + logger.info(f" Total responses sent: {stats['total_responses_sent']}") + logger.info(f" Priority contacts: {stats['priority_contacts']}") + + for power in stats["active_powers"]: + received = stats["message_counts_by_power"].get(power, 0) + sent = stats["response_counts_by_power"].get(power, 0) + rate = stats["response_rate_by_power"].get(power, 0.0) + logger.info(f" {power}: {received} received, {sent} sent, {rate:.1%} response rate") async def run(self): """Main bot loop.""" @@ -408,9 +621,7 @@ class SingleBotPlayer: else: logger.info("Bot shutting down") except GameIdException: - logger.error( - f"Game with id {self.game_id} does not exist on the server. Exiting..." - ) + logger.error(f"Game with id {self.game_id} does not exist on the server. Exiting...") finally: await self.cleanup() @@ -437,10 +648,14 @@ def parse_arguments(): parser.add_argument("--password", default="password", help="Bot password") parser.add_argument("--power", default="FRANCE", help="Power to control") parser.add_argument("--model", default="gpt-3.5-turbo", help="AI model to use") - parser.add_argument( - "--game-id", help="Game ID to join (creates new if not specified)" - ) + parser.add_argument("--game-id", help="Game ID to join (creates new if not specified)") parser.add_argument("--log-level", default="INFO", help="Logging level") + parser.add_argument( + "--negotiation-rounds", + type=int, + default=3, + help="Number of negotiation rounds per movement phase (default: 3)", + ) return parser.parse_args() @@ -459,8 +674,9 @@ async def main(): power_name=args.power, model_name=args.model, game_id=args.game_id, + negotiation_rounds=args.negotiation_rounds, ) - + config = Configuration(args.power) await bot.run() diff --git a/bot_client/websocket_diplomacy_client.py b/bot_client/websocket_diplomacy_client.py index 80b8777..b972a05 100644 --- a/bot_client/websocket_diplomacy_client.py +++ b/bot_client/websocket_diplomacy_client.py @@ -27,9 +27,7 @@ class WebSocketDiplomacyClient: but communicates with a remote server via WebSocket connections. """ - def __init__( - self, hostname: str = "localhost", port: int = 8432, use_ssl: bool = False - ): + def __init__(self, hostname: str = "localhost", port: int = 8432, use_ssl: bool = False): """ Initialize the WebSocket client. @@ -199,9 +197,7 @@ class WebSocketDiplomacyClient: return await self.channel.get_available_maps() - async def set_orders( - self, power_name: str, orders: List[str], wait: Optional[bool] = None - ) -> None: + async def set_orders(self, power_name: str, orders: List[str], wait: Optional[bool] = None) -> None: """ Set orders for a power. @@ -246,9 +242,7 @@ class WebSocketDiplomacyClient: else: await self.game.no_wait() - async def send_message( - self, sender: str, recipient: str, message: str, phase: Optional[str] = None - ) -> None: + async def send_message(self, sender: str, recipient: str, message: str, phase: Optional[str] = None) -> None: """ Send a diplomatic message. @@ -269,6 +263,57 @@ class WebSocketDiplomacyClient: logger.debug(f"Sending message from {sender} to {recipient}: {message}") await self.game.send_game_message(message=msg) + async def send_broadcast_message(self, sender: str, message: str, phase: Optional[str] = None) -> None: + """ + Send a broadcast message to all active powers. + + Args: + sender: Sending power name + message: Message content + phase: Game phase (uses current phase if None) + """ + if not self.game: + raise DiplomacyException("Must join a game first") + + if phase is None: + phase = self.game.current_short_phase + + # Send to all active powers + active_powers = [power_name for power_name, power in self.powers.items() if not power.is_eliminated() and power_name != sender] + + for recipient in active_powers: + await self.send_message(sender, recipient, message, phase) + + logger.debug(f"Broadcast message from {sender} to {len(active_powers)} powers: {message}") + + async def get_recent_messages(self, phase: Optional[str] = None, limit: int = 50) -> List[Message]: + """ + Get recent messages from the game. + + Args: + phase: Filter by specific phase (None for current phase) + limit: Maximum number of messages to return + + Returns: + List of recent messages + """ + if not self.game: + raise DiplomacyException("Must join a game first") + + if phase is None: + phase = self.game.current_short_phase + + # Get messages from the game and filter by phase + all_messages = list(self.messages.values()) + + # Filter by phase if specified + filtered_messages = [msg for msg in all_messages if msg.phase == phase] + + # Sort by time (most recent first) and limit + filtered_messages.sort(key=lambda m: m.time_sent or 0, reverse=True) + + return filtered_messages[:limit] + async def process_game(self) -> None: """ Force the game to process immediately (admin/moderator only). @@ -289,9 +334,7 @@ class WebSocketDiplomacyClient: logger.debug("Synchronizing game state") await self.game.synchronize() - async def get_phase_history( - self, from_phase: Optional[str] = None, to_phase: Optional[str] = None - ) -> List[Dict[str, Any]]: + async def get_phase_history(self, from_phase: Optional[str] = None, to_phase: Optional[str] = None) -> List[Dict[str, Any]]: """ Get phase history for the game. @@ -305,9 +348,7 @@ class WebSocketDiplomacyClient: if not self.game: raise DiplomacyException("Must join a game first") - return await self.game.get_phase_history( - from_phase=from_phase, to_phase=to_phase - ) + return await self.game.get_phase_history(from_phase=from_phase, to_phase=to_phase) async def vote(self, power_name: str, vote: str) -> None: """ diff --git a/bot_client/websocket_negotiations.py b/bot_client/websocket_negotiations.py new file mode 100644 index 0000000..4eb8bf5 --- /dev/null +++ b/bot_client/websocket_negotiations.py @@ -0,0 +1,326 @@ +""" +WebSocket-specific negotiation logic for single bot players. + +This module provides negotiation capabilities for individual bots connected +via WebSocket, adapted from the multi-agent negotiation system in the main +ai_diplomacy package. +""" + +from typing import Dict, Optional, List +from loguru import logger + +from diplomacy.engine.message import GLOBAL + +from ai_diplomacy.utils import gather_possible_orders + +from ai_diplomacy.agent import DiplomacyAgent +from ai_diplomacy.game_history import GameHistory +from websocket_diplomacy_client import WebSocketDiplomacyClient + + +async def conduct_single_bot_negotiation( + client: "WebSocketDiplomacyClient", + agent: "DiplomacyAgent", + game_history: "GameHistory", + model_error_stats: Dict[str, Dict[str, int]], + log_file_path: str, + max_rounds: int = 3, + round_number: int = 1, + prioritize_targets: Optional[List[str]] = None, +) -> bool: + """ + Conduct negotiation for a single bot during one negotiation round. + + This function handles message generation and sending for one bot during + a negotiation phase. Unlike the multi-agent version, this focuses on + a single power and sends messages via WebSocket. + + Args: + client: WebSocket diplomacy client + agent: The bot's AI agent + game_history: Game history tracker + model_error_stats: Error statistics tracking + log_file_path: Path for logging + max_rounds: Maximum number of negotiation rounds + round_number: Current round number (1-indexed) + prioritize_targets: Optional list of powers to prioritize for messaging + + Returns: + True if messages were sent successfully, False otherwise + """ + power_name = agent.power_name + logger.info(f"Starting negotiation round {round_number}/{max_rounds} for {power_name}") + + # Check if this power is eliminated + if client.get_power(power_name).is_eliminated(): + logger.info(f"{power_name} is eliminated, skipping negotiation") + return False + + # Check if this power has any orderable locations + possible_orders = gather_possible_orders(client.game, power_name) + if not possible_orders: + logger.info(f"No orderable locations for {power_name}, skipping negotiation") + return False + + # Get active powers for context + active_powers = [p_name for p_name, p_obj in client.powers.items() if not p_obj.is_eliminated()] + + # Prioritize message targets if specified + message_targets = prioritize_targets if prioritize_targets else active_powers + message_targets = [p for p in message_targets if p in active_powers and p != power_name] + + # Generate conversation messages using the AI agent + board_state = client.get_state() + + messages = await agent.client.get_conversation_reply( + game=client.game, + board_state=board_state, + power_name=power_name, + possible_orders=possible_orders, + game_history=game_history, + game_phase=client.get_current_short_phase(), + log_file_path=log_file_path, + active_powers=active_powers, + agent_goals=agent.goals, + agent_relationships=agent.relationships, + agent_private_diary_str=agent.format_private_diary_for_prompt(), + ) + + if not messages: + logger.debug(f"No messages generated for {power_name} in round {round_number}") + return False + + # Process and send each message + messages_sent = 0 + for message in messages: + success = await _send_negotiation_message(client, agent, game_history, message, power_name) + if success: + messages_sent += 1 + + logger.info(f"Sent {messages_sent}/{len(messages)} messages for {power_name}") + return messages_sent > 0 + + +async def _send_negotiation_message( + client: WebSocketDiplomacyClient, + agent: DiplomacyAgent, + game_history: GameHistory, + message: Dict, + power_name: str, +) -> bool: + """ + Send a single negotiation message via WebSocket. + + Args: + client: WebSocket diplomacy client + agent: The bot's AI agent + game_history: Game history tracker + message: Message dictionary with content and metadata + power_name: Name of the sending power + + Returns: + True if message was sent successfully, False otherwise + """ + # Validate message structure + if not isinstance(message, dict) or "content" not in message: + logger.warning(f"Invalid message format from {power_name}: {message}") + return False + + content = message.get("content", "").strip() + if not content: + logger.debug(f"Empty message content from {power_name}, skipping") + return False + + # Determine recipient + recipient = GLOBAL # Default to global + if message.get("message_type") == "private": + recipient = message.get("recipient", GLOBAL) + # Validate recipient is a valid power + if recipient not in client.powers and recipient != GLOBAL: + logger.warning(f"Invalid recipient '{recipient}' from {power_name}, sending globally") + recipient = GLOBAL + + # Send the message via WebSocket + await client.send_message( + sender=power_name, + recipient=recipient, + message=content, + phase=client.get_current_short_phase(), + ) + + # Add to game history + game_history.add_message( + phase_name=client.get_current_short_phase(), + sender=power_name, + recipient=recipient, + message_content=content, + ) + + # Add to agent's journal + journal_recipient = f"to {recipient}" if recipient != GLOBAL else "globally" + agent.add_journal_entry(f"Sent message {journal_recipient} in {client.get_current_short_phase()}: {content[:100]}...") + + logger.info(f"[{power_name} -> {recipient}] {content[:100]}...") + return True + + +async def should_participate_in_negotiations( + client: "WebSocketDiplomacyClient", + agent: "DiplomacyAgent", +) -> bool: + """ + Determine if this bot should participate in negotiations. + + Args: + client: WebSocket diplomacy client + agent: The bot's AI agent + + Returns: + True if the bot should participate in negotiations + """ + power_name = agent.power_name + + # Don't negotiate if eliminated + if client.get_power(power_name).is_eliminated(): + return False + + # Don't negotiate if no orderable locations + possible_orders = gather_possible_orders(client.game, power_name) + if not possible_orders: + return False + + # Only negotiate during movement phases + current_phase = client.get_current_short_phase() + if not current_phase.endswith("M"): + return False + + return True + + +def get_negotiation_delay(round_number: int, total_rounds: int) -> float: + """ + Calculate delay between negotiation rounds to allow message processing. + + Args: + round_number: Current round number (1-indexed) + total_rounds: Total number of rounds + + Returns: + Delay in seconds + """ + # Longer delay in early rounds to allow more strategic messaging + base_delay = 10.0 # Base delay between rounds + + if round_number == 1: + return base_delay * 1.5 # Extra time for first round + elif round_number == total_rounds: + return base_delay * 0.5 # Less time for final round + else: + return base_delay + + +async def analyze_recent_messages_for_targeting( + client: "WebSocketDiplomacyClient", + power_name: str, + max_messages: int = 20, +) -> List[str]: + """ + Analyze recent messages to identify which powers should be prioritized for negotiations. + + Args: + client: WebSocket diplomacy client + power_name: Name of the analyzing power + max_messages: Maximum number of recent messages to analyze + + Returns: + List of power names in order of priority for messaging + """ + # Get recent messages from current phase + recent_messages = await client.get_recent_messages(limit=max_messages) + + # Track who has been active and who has messaged us + message_activity = {} + direct_messages_to_us = {} + + for message in recent_messages: + sender = message.sender + recipient = message.recipient + + # Track general activity + if sender != power_name: + message_activity[sender] = message_activity.get(sender, 0) + 1 + + # Track direct messages to us + if recipient == power_name and sender != power_name: + direct_messages_to_us[sender] = direct_messages_to_us.get(sender, 0) + 1 + + # Get all active powers + active_powers = [p_name for p_name, p_obj in client.powers.items() if not p_obj.is_eliminated() and p_name != power_name] + + # Prioritize based on: 1) Powers that messaged us directly, 2) Most active powers + priority_list = [] + + # First, add powers that sent us direct messages (sorted by count) + direct_senders = sorted(direct_messages_to_us.items(), key=lambda x: x[1], reverse=True) + for sender, _ in direct_senders: + if sender in active_powers: + priority_list.append(sender) + + # Then add other active powers (sorted by activity) + remaining_powers = [p for p in active_powers if p not in priority_list] + activity_sorted = sorted(remaining_powers, key=lambda p: message_activity.get(p, 0), reverse=True) + priority_list.extend(activity_sorted) + + logger.debug(f"Message targeting priority for {power_name}: {priority_list}") + return priority_list + + +async def conduct_strategic_negotiation_round( + client: "WebSocketDiplomacyClient", + agent: "DiplomacyAgent", + game_history: "GameHistory", + model_error_stats: Dict[str, Dict[str, int]], + log_file_path: str, + round_number: int, + max_rounds: int = 3, +) -> bool: + """ + Conduct a single negotiation round with strategic message targeting. + + This function analyzes recent message activity to determine which powers + to prioritize for messaging in this round. + + Args: + client: WebSocket diplomacy client + agent: The bot's AI agent + game_history: Game history tracker + model_error_stats: Error statistics tracking + log_file_path: Path for logging + round_number: Current round number (1-indexed) + max_rounds: Maximum number of negotiation rounds + + Returns: + True if messages were sent successfully, False otherwise + """ + power_name = agent.power_name + + # Analyze recent messages to prioritize targets + priority_targets = await analyze_recent_messages_for_targeting(client, power_name) + + # Limit to top 3-4 targets in later rounds to focus conversations + if round_number > 1: + priority_targets = priority_targets[: min(4, len(priority_targets))] + + logger.info(f"Round {round_number} targets for {power_name}: {priority_targets}") + + # Conduct negotiation with prioritized targets + return await conduct_single_bot_negotiation( + client=client, + agent=agent, + game_history=game_history, + model_error_stats=model_error_stats, + log_file_path=log_file_path, + max_rounds=max_rounds, + round_number=round_number, + prioritize_targets=priority_targets, + )