WIP: Exposing errors that cause bot to fail

Working to remove all the lower level try except statements that catch
barely, and don't offer any information to top level about what went
wrong. Better to throw exceptions and let the consumers catch if needed.
This commit is contained in:
Tyler Marques 2025-06-24 19:12:22 -07:00
parent f37c344213
commit c09d5bfe7e
No known key found for this signature in database
GPG key ID: CB99EDCF41D3016F
13 changed files with 1013 additions and 400 deletions

3
.vscode/launch.json vendored
View file

@ -105,7 +105,7 @@
"--game-id", "--game-id",
"${input:gameId}", "${input:gameId}",
"--log-level", "--log-level",
"INFO" "DEBUG"
], ],
"console": "integratedTerminal", "console": "integratedTerminal",
"justMyCode": true, "justMyCode": true,
@ -138,6 +138,7 @@
"type": "pickString", "type": "pickString",
"description": "Select the AI model to use", "description": "Select the AI model to use",
"options": [ "options": [
"gemini-2.5-flash-lite-preview-06-17",
"gpt-3.5-turbo", "gpt-3.5-turbo",
"gpt-4", "gpt-4",
"claude-3-haiku", "claude-3-haiku",

View file

@ -5,6 +5,7 @@ import re
import json_repair import json_repair
import json5 import json5
from ai_diplomacy.game_history import GameHistory
from diplomacy.engine.game import Game # More forgiving JSON parser from diplomacy.engine.game import Game # More forgiving JSON parser
# Assuming BaseModelClient is importable from clients.py in the same directory # Assuming BaseModelClient is importable from clients.py in the same directory
@ -1150,6 +1151,9 @@ class DiplomacyAgent:
game_history: "GameHistory", game_history: "GameHistory",
log_file_path: str | Path, log_file_path: str | Path,
): ):
assert GameHistory is not None, (
"GameHistory object must be initialized before calling analyze_phase_and_update_state"
)
"""Analyzes the outcome of the last phase and updates goals/relationships using the LLM.""" """Analyzes the outcome of the last phase and updates goals/relationships using the LLM."""
# Use self.power_name internally # Use self.power_name internally
power_name = self.power_name power_name = self.power_name
@ -1169,7 +1173,7 @@ class DiplomacyAgent:
return return
# Get previous phase safely from history # Get previous phase safely from history
if not game_history or not game_history.phases: if not game_history.phases:
logger.warning( logger.warning(
f"[{power_name}] No game history available to analyze for {game.current_short_phase}. Skipping state update." f"[{power_name}] No game history available to analyze for {game.current_short_phase}. Skipping state update."
) )
@ -1254,9 +1258,8 @@ class DiplomacyAgent:
log_entry_success = "FALSE" # Default log_entry_success = "FALSE" # Default
update_data = None # Initialize update_data = None # Initialize
if ( # Check if response is not None and not just whitespace
response is not None and response.strip() if response is not None and response.strip():
): # Check if response is not None and not just whitespace
try: try:
update_data = self._extract_json_from_text(response) update_data = self._extract_json_from_text(response)
logger.debug( logger.debug(

View file

@ -15,7 +15,6 @@ from anthropic import AsyncAnthropic
import google.generativeai as genai import google.generativeai as genai
from diplomacy.engine.message import GLOBAL
from .game_history import GameHistory from .game_history import GameHistory
from .utils import ( from .utils import (
load_prompt, load_prompt,
@ -25,7 +24,6 @@ from .utils import (
# Import DiplomacyAgent for type hinting if needed, but avoid circular import if possible # Import DiplomacyAgent for type hinting if needed, but avoid circular import if possible
# from .agent import DiplomacyAgent # from .agent import DiplomacyAgent
from .possible_order_context import generate_rich_order_context
from .prompt_constructor import ( from .prompt_constructor import (
construct_order_generation_prompt, construct_order_generation_prompt,
build_context_prompt, build_context_prompt,
@ -113,7 +111,6 @@ class BaseModelClient:
raw_response = await run_llm_and_log( raw_response = await run_llm_and_log(
client=self, client=self,
prompt=prompt, prompt=prompt,
log_file_path=log_file_path,
power_name=power_name, power_name=power_name,
phase=phase, phase=phase,
response_type="order", # Context for run_llm_and_log's own error logging response_type="order", # Context for run_llm_and_log's own error logging
@ -553,7 +550,6 @@ class BaseModelClient:
raw_response = await run_llm_and_log( raw_response = await run_llm_and_log(
client=self, client=self,
prompt=prompt, prompt=prompt,
log_file_path=log_file_path,
power_name=power_name, power_name=power_name,
phase=game_phase, # Use game_phase for logging phase=game_phase, # Use game_phase for logging
response_type="plan_reply", # Changed from 'plan' to avoid confusion response_type="plan_reply", # Changed from 'plan' to avoid confusion
@ -604,7 +600,6 @@ class BaseModelClient:
raw_response = await run_llm_and_log( raw_response = await run_llm_and_log(
client=self, client=self,
prompt=raw_input_prompt, prompt=raw_input_prompt,
log_file_path=log_file_path,
power_name=power_name, power_name=power_name,
phase=game_phase, phase=game_phase,
response_type="negotiation", # For run_llm_and_log's internal context response_type="negotiation", # For run_llm_and_log's internal context
@ -829,7 +824,6 @@ class BaseModelClient:
raw_plan_response = await run_llm_and_log( raw_plan_response = await run_llm_and_log(
client=self, # Pass self (the client instance) client=self, # Pass self (the client instance)
prompt=full_prompt, prompt=full_prompt,
log_file_path=log_file_path,
power_name=power_name, power_name=power_name,
phase=game.current_short_phase, phase=game.current_short_phase,
response_type="plan_generation", # More specific type for run_llm_and_log context response_type="plan_generation", # More specific type for run_llm_and_log context
@ -1328,4 +1322,3 @@ def get_visible_messages_for_power(conversation_messages, power_name):
): ):
visible.append(msg) visible.append(msg)
return visible # already in chronological order if appended that way return visible # already in chronological order if appended that way

View file

@ -3,7 +3,7 @@ import logging
import re import re
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from .utils import run_llm_and_log, log_llm_response, load_prompt from .utils import run_llm_and_log, log_llm_response
if TYPE_CHECKING: if TYPE_CHECKING:
from diplomacy import Game from diplomacy import Game
@ -12,6 +12,25 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> str | None:
"""A local copy of the helper from agent.py to avoid circular imports."""
import os
try:
if prompts_dir:
filepath = os.path.join(prompts_dir, filename)
else:
current_dir = os.path.dirname(os.path.abspath(__file__))
default_prompts_dir = os.path.join(current_dir, "prompts")
filepath = os.path.join(default_prompts_dir, filename)
with open(filepath, "r", encoding="utf-8") as f:
return f.read()
except Exception as e:
logger.error(f"Error loading prompt file {filepath}: {e}")
return None
async def run_diary_consolidation( async def run_diary_consolidation(
agent: "DiplomacyAgent", agent: "DiplomacyAgent",
game: "Game", game: "Game",
@ -23,19 +42,32 @@ async def run_diary_consolidation(
Consolidate older diary entries while keeping recent ones. Consolidate older diary entries while keeping recent ones.
This is the logic moved from the DiplomacyAgent class. This is the logic moved from the DiplomacyAgent class.
""" """
logger.info(f"[{agent.power_name}] CONSOLIDATION START — {len(agent.full_private_diary)} total full entries") logger.info(
f"[{agent.power_name}] CONSOLIDATION START — "
f"{len(agent.full_private_diary)} total full entries"
)
full_entries = [e for e in agent.full_private_diary if not e.startswith("[CONSOLIDATED HISTORY]")] full_entries = [
e
for e in agent.full_private_diary
if not e.startswith("[CONSOLIDATED HISTORY]")
]
if len(full_entries) <= entries_to_keep_unsummarized: if len(full_entries) <= entries_to_keep_unsummarized:
agent.private_diary = list(agent.full_private_diary) agent.private_diary = list(agent.full_private_diary)
logger.info(f"[{agent.power_name}] ≤ {entries_to_keep_unsummarized} full entries — skipping consolidation") logger.info(
f"[{agent.power_name}] ≤ {entries_to_keep_unsummarized} full entries — "
"skipping consolidation"
)
return return
boundary_entry = full_entries[-entries_to_keep_unsummarized] boundary_entry = full_entries[-entries_to_keep_unsummarized]
match = re.search(r"\[[SFWRAB]\s*(\d{4})", boundary_entry) match = re.search(r"\[[SFWRAB]\s*(\d{4})", boundary_entry)
if not match: if not match:
logger.error(f"[{agent.power_name}] Could not parse year from boundary entry; aborting consolidation") logger.error(
f"[{agent.power_name}] Could not parse year from boundary entry; "
"aborting consolidation"
)
agent.private_diary = list(agent.full_private_diary) agent.private_diary = list(agent.full_private_diary)
return return
@ -46,19 +78,37 @@ async def run_diary_consolidation(
m = re.search(r"\[[SFWRAB]\s*(\d{4})", entry) m = re.search(r"\[[SFWRAB]\s*(\d{4})", entry)
return int(m.group(1)) if m else None return int(m.group(1)) if m else None
entries_to_summarize = [e for e in full_entries if (_entry_year(e) is not None and _entry_year(e) < cutoff_year)] entries_to_summarize = [
entries_to_keep = [e for e in full_entries if (_entry_year(e) is None or _entry_year(e) >= cutoff_year)] e
for e in full_entries
if (_entry_year(e) is not None and _entry_year(e) < cutoff_year)
]
entries_to_keep = [
e
for e in full_entries
if (_entry_year(e) is None or _entry_year(e) >= cutoff_year)
]
logger.info(f"[{agent.power_name}] Summarising {len(entries_to_summarize)} entries; keeping {len(entries_to_keep)} recent entries verbatim") logger.info(
f"[{agent.power_name}] Summarising {len(entries_to_summarize)} entries; "
f"keeping {len(entries_to_keep)} recent entries verbatim"
)
if not entries_to_summarize: if not entries_to_summarize:
agent.private_diary = list(agent.full_private_diary) agent.private_diary = list(agent.full_private_diary)
logger.warning(f"[{agent.power_name}] No eligible entries to summarise; context diary left unchanged") logger.warning(
f"[{agent.power_name}] No eligible entries to summarise; "
"context diary left unchanged"
)
return return
prompt_template = load_prompt("diary_consolidation_prompt.txt", prompts_dir=prompts_dir) prompt_template = _load_prompt_file(
"diary_consolidation_prompt.txt", prompts_dir=prompts_dir
)
if not prompt_template: if not prompt_template:
logger.error(f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting") logger.error(
f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting"
)
return return
prompt = prompt_template.format( prompt = prompt_template.format(
@ -87,14 +137,23 @@ async def run_diary_consolidation(
new_summary_entry = f"[CONSOLIDATED HISTORY] {consolidated_text}" new_summary_entry = f"[CONSOLIDATED HISTORY] {consolidated_text}"
agent.private_diary = [new_summary_entry] + entries_to_keep agent.private_diary = [new_summary_entry] + entries_to_keep
success_flag = "TRUE" success_flag = "TRUE"
logger.info(f"[{agent.power_name}] Consolidation complete — {len(agent.private_diary)} context entries now") logger.info(
f"[{agent.power_name}] Consolidation complete — "
f"{len(agent.private_diary)} context entries now"
)
except Exception as exc: except Exception as exc:
logger.error(f"[{agent.power_name}] Diary consolidation failed: {exc}", exc_info=True) logger.error(
f"[{agent.power_name}] Diary consolidation failed: {exc}", exc_info=True
)
finally: finally:
log_llm_response( log_llm_response(
log_file_path=log_file_path, log_file_path=log_file_path,
model_name=(consolidation_client.model_name if consolidation_client is not None else agent.client.model_name), model_name=(
consolidation_client.model_name
if consolidation_client is not None
else agent.client.model_name
),
power_name=agent.power_name, power_name=agent.power_name,
phase=game.current_short_phase, phase=game.current_short_phase,
response_type="diary_consolidation", response_type="diary_consolidation",
@ -102,3 +161,4 @@ async def run_diary_consolidation(
raw_response=raw_response, raw_response=raw_response,
success=success_flag, success=success_flag,
) )

View file

@ -148,7 +148,6 @@ class GameHistory:
return {} return {}
return self.phases[-1].plans return self.phases[-1].plans
# NEW METHOD
def get_messages_this_round(self, power_name: str, current_phase_name: str) -> str: def get_messages_this_round(self, power_name: str, current_phase_name: str) -> str:
current_phase: Optional[Phase] = None current_phase: Optional[Phase] = None
for phase_obj in self.phases: for phase_obj in self.phases:

View file

@ -1,4 +1,4 @@
from os.path import exists import datetime
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
from pathlib import Path from pathlib import Path
import warnings import warnings
@ -6,16 +6,16 @@ import warnings
class Configuration(BaseSettings): class Configuration(BaseSettings):
DEBUG: bool = False DEBUG: bool = False
log_file_path: Path = Path("./logs/logs.txt") log_file_path: Path
DEEPSEEK_API_KEY: str | None = None DEEPSEEK_API_KEY: str | None = None
OPENAI_API_KEY: str | None = None OPENAI_API_KEY: str | None = None
ANTHROPIC_API_KEY: str | None = None ANTHROPIC_API_KEY: str | None = None
GEMINI_API_KEY: str | None = None GEMINI_API_KEY: str | None = None
OPENROUTER_API_KEY: str | None = None OPENROUTER_API_KEY: str | None = None
def __init__(self, **kwargs): def __init__(self, power_name, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.log_file_path = Path(f"./logs/{datetime.datetime.now().strftime('%d-%m-%y_%H:%M')}/{power_name}.txt")
# Make the path absolute, gets rid of weirdness of calling this in different places # Make the path absolute, gets rid of weirdness of calling this in different places
self.log_file_path = self.log_file_path.resolve() self.log_file_path = self.log_file_path.resolve()
self.log_file_path.parent.mkdir(parents=True, exist_ok=True) self.log_file_path.parent.mkdir(parents=True, exist_ok=True)
@ -42,11 +42,7 @@ class Configuration(BaseSettings):
"""Override to check for empty API keys at access time""" """Override to check for empty API keys at access time"""
value = super().__getattribute__(name) value = super().__getattribute__(name)
if name.endswith("_KEY") and ( if name.endswith("_KEY") and (not value or (isinstance(value, str) and len(value) == 0)):
not value or (isinstance(value, str) and len(value) == 0) raise ValueError(f"API key '{name}' is not set or is empty. Please configure it before use.")
):
raise ValueError(
f"API key '{name}' is not set or is empty. Please configure it before use."
)
return value return value

View file

@ -30,12 +30,9 @@ from websocket_diplomacy_client import (
from models import ( from models import (
SignInRequest, SignInRequest,
CreateGameRequest, CreateGameRequest,
JoinGameRequest,
SetOrdersRequest, SetOrdersRequest,
ListGamesRequest, ListGamesRequest,
ProcessGameRequest,
serialize_message, serialize_message,
parse_message,
) )
import sys import sys
@ -64,9 +61,7 @@ logging.getLogger("root").setLevel(logging.WARNING)
def parse_arguments(): def parse_arguments():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(description="Run a Diplomacy game simulation with WebSocket server connection.")
description="Run a Diplomacy game simulation with WebSocket server connection."
)
parser.add_argument( parser.add_argument(
"--hostname", "--hostname",
type=str, type=str,
@ -137,9 +132,7 @@ def parse_arguments():
return parser.parse_args() return parser.parse_args()
async def join_powers_for_testing( async def join_powers_for_testing(client: WebSocketDiplomacyClient, power_model_map: dict):
client: WebSocketDiplomacyClient, power_model_map: dict
):
""" """
Join multiple powers in the same game for testing purposes. Join multiple powers in the same game for testing purposes.
This simulates having multiple AI players in one game. This simulates having multiple AI players in one game.
@ -150,17 +143,13 @@ async def join_powers_for_testing(
for power_name in power_names[1:]: for power_name in power_names[1:]:
try: try:
logger.info(f"Attempting to join power {power_name}") logger.info(f"Attempting to join power {power_name}")
await client.channel.join_game( await client.channel.join_game(game_id=client.game_id, power_name=power_name)
game_id=client.game_id, power_name=power_name
)
logger.info(f"Successfully joined {power_name}") logger.info(f"Successfully joined {power_name}")
except Exception as e: except Exception as e:
logger.warning(f"Could not join {power_name}: {e}") logger.warning(f"Could not join {power_name}: {e}")
async def create_or_join_game( async def create_or_join_game(client: WebSocketDiplomacyClient, args, power_model_map: dict):
client: WebSocketDiplomacyClient, args, power_model_map: dict
):
""" """
Create a new game or join an existing one based on arguments. Create a new game or join an existing one based on arguments.
""" """
@ -171,9 +160,7 @@ async def create_or_join_game(
# List available games first to see what's available # List available games first to see what's available
try: try:
games = await client.list_games() games = await client.list_games()
logger.info( logger.info(f"Available games: {[g.get('game_id', 'unknown') for g in games]}")
f"Available games: {[g.get('game_id', 'unknown') for g in games]}"
)
except Exception as e: except Exception as e:
logger.warning(f"Could not list games: {e}") logger.warning(f"Could not list games: {e}")
@ -189,19 +176,13 @@ async def create_or_join_game(
logger.info("Creating new game") logger.info("Creating new game")
# Get the first power to control # Get the first power to control
first_power = ( first_power = list(power_model_map.keys())[0] if not args.create_multi_power_game else None
list(power_model_map.keys())[0]
if not args.create_multi_power_game
else None
)
game = await client.create_game( game = await client.create_game(
map_name="standard", map_name="standard",
rules=["NO_PRESS", "IGNORE_ERRORS", "POWER_CHOICE"], rules=["NO_PRESS", "IGNORE_ERRORS", "POWER_CHOICE"],
power_name=first_power, power_name=first_power,
n_controls=7 n_controls=7 if not args.create_multi_power_game else 1, # Lower requirement for testing
if not args.create_multi_power_game
else 1, # Lower requirement for testing
deadline=None, # No time pressure for AI testing deadline=None, # No time pressure for AI testing
) )
@ -223,11 +204,7 @@ async def demonstrate_typed_messages(client: WebSocketDiplomacyClient, args):
logger.info("=== Typed Message Examples ===") logger.info("=== Typed Message Examples ===")
# Example 1: Authentication Request # Example 1: Authentication Request
sign_in_request = SignInRequest( sign_in_request = SignInRequest(request_id=str(uuid.uuid4()), username=args.username, password=args.password)
request_id=str(uuid.uuid4()),
username=args.username,
password=args.password
)
logger.info(f"SignIn request: {serialize_message(sign_in_request)}") logger.info(f"SignIn request: {serialize_message(sign_in_request)}")
# Example 2: Create Game Request # Example 2: Create Game Request
@ -237,26 +214,22 @@ async def demonstrate_typed_messages(client: WebSocketDiplomacyClient, args):
map_name="standard", map_name="standard",
rules=["NO_PRESS", "IGNORE_ERRORS", "POWER_CHOICE"], rules=["NO_PRESS", "IGNORE_ERRORS", "POWER_CHOICE"],
n_controls=1, n_controls=1,
power_name="FRANCE" power_name="FRANCE",
) )
logger.info(f"CreateGame request: {serialize_message(create_game_request)}") logger.info(f"CreateGame request: {serialize_message(create_game_request)}")
# Example 3: List Games Request # Example 3: List Games Request
list_games_request = ListGamesRequest( list_games_request = ListGamesRequest(request_id=str(uuid.uuid4()), token="example_auth_token_here", include_protected=False)
request_id=str(uuid.uuid4()),
token="example_auth_token_here",
include_protected=False
)
logger.info(f"ListGames request: {serialize_message(list_games_request)}") logger.info(f"ListGames request: {serialize_message(list_games_request)}")
# Example 4: Set Orders Request (would be used in game loop) # Example 4: Set Orders Request (would be used in game loop)
if hasattr(client, 'game_id') and hasattr(client, 'token'): if hasattr(client, "game_id") and hasattr(client, "token"):
set_orders_request = SetOrdersRequest( set_orders_request = SetOrdersRequest(
request_id=str(uuid.uuid4()), request_id=str(uuid.uuid4()),
token=client.token or "example_token", token=client.token or "example_token",
game_id=client.game_id or "example_game_id", game_id=client.game_id or "example_game_id",
game_role="FRANCE", game_role="FRANCE",
orders=["A PAR H", "F BRE H", "A MAR H"] orders=["A PAR H", "F BRE H", "A MAR H"],
) )
logger.info(f"SetOrders request: {serialize_message(set_orders_request)}") logger.info(f"SetOrders request: {serialize_message(set_orders_request)}")
@ -282,9 +255,7 @@ async def main():
logger.info("Starting WebSocket-based Diplomacy game with multiple LLMs") logger.info("Starting WebSocket-based Diplomacy game with multiple LLMs")
start_whole = time.time() start_whole = time.time()
model_error_stats = defaultdict( model_error_stats = defaultdict(lambda: {"conversation_errors": 0, "order_decoding_errors": 0})
lambda: {"conversation_errors": 0, "order_decoding_errors": 0}
)
# Determine the result folder based on a timestamp # Determine the result folder based on a timestamp
timestamp_str = time.strftime("%Y%m%d_%H%M%S") timestamp_str = time.strftime("%Y%m%d_%H%M%S")
@ -306,9 +277,7 @@ async def main():
# File paths # File paths
manifesto_path = f"{result_folder}/game_manifesto.txt" manifesto_path = f"{result_folder}/game_manifesto.txt"
game_file_path = ( game_file_path = args.output if args.output else f"{result_folder}/lmvsgame_websocket.json"
args.output if args.output else f"{result_folder}/lmvsgame_websocket.json"
)
overview_file_path = f"{result_folder}/overview.jsonl" overview_file_path = f"{result_folder}/overview.jsonl"
llm_log_file_path = f"{result_folder}/llm_responses.csv" llm_log_file_path = f"{result_folder}/llm_responses.csv"
@ -325,9 +294,7 @@ async def main():
] ]
provided_models = [name.strip() for name in args.models.split(",")] provided_models = [name.strip() for name in args.models.split(",")]
if len(provided_models) != len(powers_order): if len(provided_models) != len(powers_order):
logger.error( logger.error(f"Expected {len(powers_order)} models for --models but got {len(provided_models)}. Exiting.")
f"Expected {len(powers_order)} models for --models but got {len(provided_models)}. Exiting."
)
return return
power_model_map = dict(zip(powers_order, provided_models)) power_model_map = dict(zip(powers_order, provided_models))
else: else:
@ -377,32 +344,20 @@ async def main():
client_obj = load_model_client(model_id) client_obj = load_model_client(model_id)
agent = DiplomacyAgent(power_name=power_name, client=client_obj) agent = DiplomacyAgent(power_name=power_name, client=client_obj)
agents[power_name] = agent agents[power_name] = agent
logger.info( logger.info(f"Preparing initialization task for {power_name} with model {model_id}")
f"Preparing initialization task for {power_name} with model {model_id}" initialization_tasks.append(initialize_agent_state_ext(agent, client.game, game_history, llm_log_file_path))
)
initialization_tasks.append(
initialize_agent_state_ext(
agent, client.game, game_history, llm_log_file_path
)
)
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Failed to create agent for {power_name} with model {model_id}: {e}", f"Failed to create agent for {power_name} with model {model_id}: {e}",
exc_info=True, exc_info=True,
) )
else: else:
logger.info( logger.info(f"Skipping agent initialization for {power_name} (no model or eliminated)")
f"Skipping agent initialization for {power_name} (no model or eliminated)"
)
# Run initializations concurrently # Run initializations concurrently
if initialization_tasks: if initialization_tasks:
logger.info( logger.info(f"Running {len(initialization_tasks)} agent initializations concurrently...")
f"Running {len(initialization_tasks)} agent initializations concurrently..." initialization_results = await asyncio.gather(*initialization_tasks, return_exceptions=True)
)
initialization_results = await asyncio.gather(
*initialization_tasks, return_exceptions=True
)
initialized_powers = list(agents.keys()) initialized_powers = list(agents.keys())
for i, result in enumerate(initialization_results): for i, result in enumerate(initialization_results):
@ -414,9 +369,7 @@ async def main():
exc_info=result, exc_info=result,
) )
else: else:
logger.info( logger.info(f"Successfully initialized agent state for {power_name}.")
f"Successfully initialized agent state for {power_name}."
)
# Main game loop # Main game loop
all_phase_relationships = {} all_phase_relationships = {}
@ -433,9 +386,7 @@ async def main():
game_history.add_phase(current_phase) game_history.add_phase(current_phase)
current_short_phase = client.get_current_short_phase() current_short_phase = client.get_current_short_phase()
logger.info( logger.info(f"PHASE: {current_phase} (time so far: {phase_start - start_whole:.2f}s)")
f"PHASE: {current_phase} (time so far: {phase_start - start_whole:.2f}s)"
)
# Prevent unbounded simulation based on year # Prevent unbounded simulation based on year
year_str = current_phase[1:5] year_str = current_phase[1:5]
@ -447,9 +398,7 @@ async def main():
# Negotiations for movement phases # Negotiations for movement phases
if client.get_current_short_phase().endswith("M"): if client.get_current_short_phase().endswith("M"):
if args.num_negotiation_rounds > 0: if args.num_negotiation_rounds > 0:
logger.info( logger.info(f"Running {args.num_negotiation_rounds} rounds of negotiations...")
f"Running {args.num_negotiation_rounds} rounds of negotiations..."
)
game_history = await conduct_negotiations( game_history = await conduct_negotiations(
client.game, # Pass the NetworkGame object client.game, # Pass the NetworkGame object
agents, agents,
@ -459,9 +408,7 @@ async def main():
log_file_path=llm_log_file_path, log_file_path=llm_log_file_path,
) )
else: else:
logger.info( logger.info("Skipping negotiation phase as num_negotiation_rounds=0")
"Skipping negotiation phase as num_negotiation_rounds=0"
)
# Planning phase (if enabled) # Planning phase (if enabled)
if args.planning_phase: if args.planning_phase:
@ -475,29 +422,19 @@ async def main():
) )
# Generate negotiation diary entries # Generate negotiation diary entries
logger.info( logger.info(f"Generating negotiation diary entries for phase {current_short_phase}...")
f"Generating negotiation diary entries for phase {current_short_phase}..." active_powers_for_neg_diary = [p for p in agents.keys() if not client.get_power(p).is_eliminated()]
)
active_powers_for_neg_diary = [
p for p in agents.keys() if not client.get_power(p).is_eliminated()
]
neg_diary_tasks = [] neg_diary_tasks = []
for power_name, agent in agents.items(): for power_name, agent in agents.items():
if not client.get_power(power_name).is_eliminated(): if not client.get_power(power_name).is_eliminated():
neg_diary_tasks.append( neg_diary_tasks.append(agent.generate_negotiation_diary_entry(client.game, game_history, llm_log_file_path))
agent.generate_negotiation_diary_entry(
client.game, game_history, llm_log_file_path
)
)
if neg_diary_tasks: if neg_diary_tasks:
await asyncio.gather(*neg_diary_tasks, return_exceptions=True) await asyncio.gather(*neg_diary_tasks, return_exceptions=True)
# AI Decision Making: Get orders for each controlled power # AI Decision Making: Get orders for each controlled power
logger.info("Getting orders from agents...") logger.info("Getting orders from agents...")
active_powers_for_orders = [ active_powers_for_orders = [p for p in agents.keys() if not client.get_power(p).is_eliminated()]
p for p in agents.keys() if not client.get_power(p).is_eliminated()
]
order_tasks = [] order_tasks = []
order_power_names = [] order_power_names = []
@ -505,35 +442,23 @@ async def main():
for power_name, agent in agents.items(): for power_name, agent in agents.items():
if client.get_power(power_name).is_eliminated(): if client.get_power(power_name).is_eliminated():
logger.debug( logger.debug(f"Skipping order generation for eliminated power {power_name}.")
f"Skipping order generation for eliminated power {power_name}."
)
continue continue
# Diagnostic logging # Diagnostic logging
logger.info( logger.info(f"--- Diagnostic Log for {power_name} in phase {current_phase} ---")
f"--- Diagnostic Log for {power_name} in phase {current_phase} ---"
)
try: try:
orderable_locs = client.get_orderable_locations(power_name) orderable_locs = client.get_orderable_locations(power_name)
logger.info( logger.info(f"[{power_name}][{current_phase}] Orderable locations: {orderable_locs}")
f"[{power_name}][{current_phase}] Orderable locations: {orderable_locs}"
)
actual_units = client.get_units(power_name) actual_units = client.get_units(power_name)
logger.info( logger.info(f"[{power_name}][{current_phase}] Actual units: {actual_units}")
f"[{power_name}][{current_phase}] Actual units: {actual_units}"
)
except Exception as e_diag: except Exception as e_diag:
logger.error( logger.error(f"[{power_name}][{current_phase}] Error during diagnostic logging: {e_diag}")
f"[{power_name}][{current_phase}] Error during diagnostic logging: {e_diag}"
)
# Calculate possible orders # Calculate possible orders
possible_orders = gather_possible_orders(client.game, power_name) possible_orders = gather_possible_orders(client.game, power_name)
if not possible_orders: if not possible_orders:
logger.debug( logger.debug(f"No orderable locations for {power_name}; submitting empty orders.")
f"No orderable locations for {power_name}; submitting empty orders."
)
await client.set_orders(power_name, []) await client.set_orders(power_name, [])
continue continue
@ -559,12 +484,8 @@ async def main():
# Run order generation concurrently # Run order generation concurrently
if order_tasks: if order_tasks:
logger.debug( logger.debug(f"Running {len(order_tasks)} order generation tasks concurrently...")
f"Running {len(order_tasks)} order generation tasks concurrently..." order_results = await asyncio.gather(*order_tasks, return_exceptions=True)
)
order_results = await asyncio.gather(
*order_tasks, return_exceptions=True
)
else: else:
order_results = [] order_results = []
@ -580,9 +501,7 @@ async def main():
) )
await client.set_orders(p_name, []) await client.set_orders(p_name, [])
elif result is None: elif result is None:
logger.warning( logger.warning(f"get_valid_orders returned None for {p_name}. Setting empty orders.")
f"get_valid_orders returned None for {p_name}. Setting empty orders."
)
await client.set_orders(p_name, []) await client.set_orders(p_name, [])
else: else:
orders = result orders = result
@ -600,15 +519,11 @@ async def main():
# await client.send_typed_message(set_orders_msg) # await client.send_typed_message(set_orders_msg)
await client.set_orders(p_name, orders) await client.set_orders(p_name, orders)
logger.debug( logger.debug(f"Set orders for {p_name} in {current_short_phase}: {orders}")
f"Set orders for {p_name} in {current_short_phase}: {orders}"
)
# Generate order diary entry # Generate order diary entry
try: try:
await agent.generate_order_diary_entry( await agent.generate_order_diary_entry(client.game, orders, llm_log_file_path)
client.game, orders, llm_log_file_path
)
except Exception as e_diary: except Exception as e_diary:
logger.error( logger.error(
f"Error generating order diary for {p_name}: {e_diary}", f"Error generating order diary for {p_name}: {e_diary}",
@ -650,22 +565,13 @@ async def main():
# Collect relationships for this phase # Collect relationships for this phase
current_relationships_for_phase = {} current_relationships_for_phase = {}
for power_name, agent in agents.items(): for power_name, agent in agents.items():
if ( if power_name in client.powers and not client.get_power(power_name).is_eliminated():
power_name in client.powers
and not client.get_power(power_name).is_eliminated()
):
current_relationships_for_phase[power_name] = agent.relationships current_relationships_for_phase[power_name] = agent.relationships
all_phase_relationships[current_short_phase] = ( all_phase_relationships[current_short_phase] = current_relationships_for_phase
current_relationships_for_phase
)
# Generate phase result diary entries # Generate phase result diary entries
logger.info( logger.info(f"Generating phase result diary entries for completed phase {current_phase}...")
f"Generating phase result diary entries for completed phase {current_phase}..." phase_summary = getattr(client.game, "phase_summaries", {}).get(current_phase, "(Summary not generated)")
)
phase_summary = getattr(client.game, "phase_summaries", {}).get(
current_phase, "(Summary not generated)"
)
all_orders_this_phase = current_order_history all_orders_this_phase = current_order_history
phase_result_diary_tasks = [] phase_result_diary_tasks = []
@ -685,16 +591,10 @@ async def main():
await asyncio.gather(*phase_result_diary_tasks, return_exceptions=True) await asyncio.gather(*phase_result_diary_tasks, return_exceptions=True)
# State update analysis # State update analysis
logger.info( logger.info(f"Starting state update analysis for completed phase {current_phase}...")
f"Starting state update analysis for completed phase {current_phase}..."
)
current_board_state = client.get_state() current_board_state = client.get_state()
active_agent_powers = [ active_agent_powers = [(p, power) for p, power in client.powers.items() if p in agents and not power.is_eliminated()]
(p, power)
for p, power in client.powers.items()
if p in agents and not power.is_eliminated()
]
if active_agent_powers: if active_agent_powers:
state_update_tasks = [] state_update_tasks = []
@ -781,4 +681,3 @@ async def main():
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View file

@ -8,15 +8,17 @@ and notifications (server -> client).
from typing import Optional, List, Dict, Any, Union, Literal from typing import Optional, List, Dict, Any, Union, Literal
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from abc import ABC, abstractmethod from abc import ABC
# ============================================================================= # =============================================================================
# Base Message Classes # Base Message Classes
# ============================================================================= # =============================================================================
class BaseMessage(BaseModel, ABC): class BaseMessage(BaseModel, ABC):
"""Base class for all WebSocket messages.""" """Base class for all WebSocket messages."""
name: str name: str
class Config: class Config:
@ -25,17 +27,20 @@ class BaseMessage(BaseModel, ABC):
class BaseRequest(BaseMessage): class BaseRequest(BaseMessage):
"""Base class for all client -> server requests.""" """Base class for all client -> server requests."""
request_id: str request_id: str
re_sent: bool = False re_sent: bool = False
class BaseResponse(BaseMessage): class BaseResponse(BaseMessage):
"""Base class for all server -> client responses.""" """Base class for all server -> client responses."""
request_id: str request_id: str
class BaseNotification(BaseMessage): class BaseNotification(BaseMessage):
"""Base class for all server -> client notifications (no request_id).""" """Base class for all server -> client notifications (no request_id)."""
pass pass
@ -43,8 +48,10 @@ class BaseNotification(BaseMessage):
# Authentication & Connection Level Messages # Authentication & Connection Level Messages
# ============================================================================= # =============================================================================
class SignInRequest(BaseRequest): class SignInRequest(BaseRequest):
"""Client authentication request.""" """Client authentication request."""
name: Literal["sign_in"] = "sign_in" name: Literal["sign_in"] = "sign_in"
username: str username: str
password: str password: str
@ -52,6 +59,7 @@ class SignInRequest(BaseRequest):
class GetDaidePortRequest(BaseRequest): class GetDaidePortRequest(BaseRequest):
"""Request DAIDE TCP port for a game.""" """Request DAIDE TCP port for a game."""
name: Literal["get_daide_port"] = "get_daide_port" name: Literal["get_daide_port"] = "get_daide_port"
game_id: str game_id: str
@ -60,13 +68,16 @@ class GetDaidePortRequest(BaseRequest):
# Channel Level Messages (require authentication token) # Channel Level Messages (require authentication token)
# ============================================================================= # =============================================================================
class ChannelRequest(BaseRequest): class ChannelRequest(BaseRequest):
"""Base class for channel-level requests that require authentication.""" """Base class for channel-level requests that require authentication."""
token: str token: str
class CreateGameRequest(ChannelRequest): class CreateGameRequest(ChannelRequest):
"""Create a new game.""" """Create a new game."""
name: Literal["create_game"] = "create_game" name: Literal["create_game"] = "create_game"
map_name: str = "standard" map_name: str = "standard"
rules: List[str] = Field(default_factory=lambda: ["NO_PRESS", "IGNORE_ERRORS"]) rules: List[str] = Field(default_factory=lambda: ["NO_PRESS", "IGNORE_ERRORS"])
@ -78,6 +89,7 @@ class CreateGameRequest(ChannelRequest):
class JoinGameRequest(ChannelRequest): class JoinGameRequest(ChannelRequest):
"""Join an existing game.""" """Join an existing game."""
name: Literal["join_game"] = "join_game" name: Literal["join_game"] = "join_game"
game_id: str game_id: str
power_name: Optional[str] = None power_name: Optional[str] = None
@ -86,6 +98,7 @@ class JoinGameRequest(ChannelRequest):
class JoinPowersRequest(ChannelRequest): class JoinPowersRequest(ChannelRequest):
"""Join multiple powers in a game.""" """Join multiple powers in a game."""
name: Literal["join_powers"] = "join_powers" name: Literal["join_powers"] = "join_powers"
game_id: str game_id: str
power_names: List[str] power_names: List[str]
@ -94,6 +107,7 @@ class JoinPowersRequest(ChannelRequest):
class ListGamesRequest(ChannelRequest): class ListGamesRequest(ChannelRequest):
"""List available games.""" """List available games."""
name: Literal["list_games"] = "list_games" name: Literal["list_games"] = "list_games"
game_id_filter: Optional[str] = None game_id_filter: Optional[str] = None
map_name: Optional[str] = None map_name: Optional[str] = None
@ -103,23 +117,27 @@ class ListGamesRequest(ChannelRequest):
class GetPlayablePowersRequest(ChannelRequest): class GetPlayablePowersRequest(ChannelRequest):
"""Get uncontrolled powers in a game.""" """Get uncontrolled powers in a game."""
name: Literal["get_playable_powers"] = "get_playable_powers" name: Literal["get_playable_powers"] = "get_playable_powers"
game_id: str game_id: str
class GetAvailableMapsRequest(ChannelRequest): class GetAvailableMapsRequest(ChannelRequest):
"""Get list of available maps.""" """Get list of available maps."""
name: Literal["get_available_maps"] = "get_available_maps" name: Literal["get_available_maps"] = "get_available_maps"
class GetDummyWaitingPowersRequest(ChannelRequest): class GetDummyWaitingPowersRequest(ChannelRequest):
"""Get AI-controllable powers (bot use).""" """Get AI-controllable powers (bot use)."""
name: Literal["get_dummy_waiting_powers"] = "get_dummy_waiting_powers" name: Literal["get_dummy_waiting_powers"] = "get_dummy_waiting_powers"
game_id: str game_id: str
class SetGradeRequest(ChannelRequest): class SetGradeRequest(ChannelRequest):
"""Modify user permissions.""" """Modify user permissions."""
name: Literal["set_grade"] = "set_grade" name: Literal["set_grade"] = "set_grade"
username: str username: str
grade: str grade: str
@ -127,11 +145,13 @@ class SetGradeRequest(ChannelRequest):
class DeleteAccountRequest(ChannelRequest): class DeleteAccountRequest(ChannelRequest):
"""Delete user account.""" """Delete user account."""
name: Literal["delete_account"] = "delete_account" name: Literal["delete_account"] = "delete_account"
class LogoutRequest(ChannelRequest): class LogoutRequest(ChannelRequest):
"""Disconnect from server.""" """Disconnect from server."""
name: Literal["logout"] = "logout" name: Literal["logout"] = "logout"
@ -139,8 +159,10 @@ class LogoutRequest(ChannelRequest):
# Game Level Messages (require authentication + game context) # Game Level Messages (require authentication + game context)
# ============================================================================= # =============================================================================
class GameRequest(ChannelRequest): class GameRequest(ChannelRequest):
"""Base class for game-level requests.""" """Base class for game-level requests."""
game_id: str game_id: str
game_role: str # Power name like "ENGLAND" game_role: str # Power name like "ENGLAND"
phase: Optional[str] = None phase: Optional[str] = None
@ -148,18 +170,21 @@ class GameRequest(ChannelRequest):
class SetOrdersRequest(GameRequest): class SetOrdersRequest(GameRequest):
"""Submit orders for a power.""" """Submit orders for a power."""
name: Literal["set_orders"] = "set_orders" name: Literal["set_orders"] = "set_orders"
orders: List[str] orders: List[str]
class SetWaitFlagRequest(GameRequest): class SetWaitFlagRequest(GameRequest):
"""Set wait flag for turn processing.""" """Set wait flag for turn processing."""
name: Literal["set_wait_flag"] = "set_wait_flag" name: Literal["set_wait_flag"] = "set_wait_flag"
wait: bool wait: bool
class SendGameMessageRequest(GameRequest): class SendGameMessageRequest(GameRequest):
"""Send diplomatic message.""" """Send diplomatic message."""
name: Literal["send_game_message"] = "send_game_message" name: Literal["send_game_message"] = "send_game_message"
recipient: str # Power name or "GLOBAL" recipient: str # Power name or "GLOBAL"
message: str message: str
@ -168,11 +193,13 @@ class SendGameMessageRequest(GameRequest):
class GetAllPossibleOrdersRequest(GameRequest): class GetAllPossibleOrdersRequest(GameRequest):
"""Get legal orders for current phase.""" """Get legal orders for current phase."""
name: Literal["get_all_possible_orders"] = "get_all_possible_orders" name: Literal["get_all_possible_orders"] = "get_all_possible_orders"
class GetPhaseHistoryRequest(GameRequest): class GetPhaseHistoryRequest(GameRequest):
"""Get historical game phases.""" """Get historical game phases."""
name: Literal["get_phase_history"] = "get_phase_history" name: Literal["get_phase_history"] = "get_phase_history"
from_phase: Optional[str] = None from_phase: Optional[str] = None
to_phase: Optional[str] = None to_phase: Optional[str] = None
@ -180,45 +207,53 @@ class GetPhaseHistoryRequest(GameRequest):
class ProcessGameRequest(GameRequest): class ProcessGameRequest(GameRequest):
"""Force game processing (master only).""" """Force game processing (master only)."""
name: Literal["process_game"] = "process_game" name: Literal["process_game"] = "process_game"
class VoteRequest(GameRequest): class VoteRequest(GameRequest):
"""Vote for/against draw.""" """Vote for/against draw."""
name: Literal["vote"] = "vote" name: Literal["vote"] = "vote"
vote: Literal["yes", "no"] vote: Literal["yes", "no"]
class SaveGameRequest(GameRequest): class SaveGameRequest(GameRequest):
"""Export game as JSON.""" """Export game as JSON."""
name: Literal["save_game"] = "save_game" name: Literal["save_game"] = "save_game"
class SetGameStateRequest(GameRequest): class SetGameStateRequest(GameRequest):
"""Modify game state (master only).""" """Modify game state (master only)."""
name: Literal["set_game_state"] = "set_game_state" name: Literal["set_game_state"] = "set_game_state"
state: Dict[str, Any] state: Dict[str, Any]
class SetGameStatusRequest(GameRequest): class SetGameStatusRequest(GameRequest):
"""Change game status (master only).""" """Change game status (master only)."""
name: Literal["set_game_status"] = "set_game_status" name: Literal["set_game_status"] = "set_game_status"
status: str status: str
class SetDummyPowersRequest(GameRequest): class SetDummyPowersRequest(GameRequest):
"""Make powers AI-controlled (master only).""" """Make powers AI-controlled (master only)."""
name: Literal["set_dummy_powers"] = "set_dummy_powers" name: Literal["set_dummy_powers"] = "set_dummy_powers"
power_names: List[str] power_names: List[str]
class DeleteGameRequest(GameRequest): class DeleteGameRequest(GameRequest):
"""Delete game (master only).""" """Delete game (master only)."""
name: Literal["delete_game"] = "delete_game" name: Literal["delete_game"] = "delete_game"
class LeaveGameRequest(GameRequest): class LeaveGameRequest(GameRequest):
"""Leave game.""" """Leave game."""
name: Literal["leave_game"] = "leave_game" name: Literal["leave_game"] = "leave_game"
@ -226,13 +261,16 @@ class LeaveGameRequest(GameRequest):
# Response Messages (Server -> Client) # Response Messages (Server -> Client)
# ============================================================================= # =============================================================================
class OkResponse(BaseResponse): class OkResponse(BaseResponse):
"""Generic success response.""" """Generic success response."""
name: Literal["ok"] = "ok" name: Literal["ok"] = "ok"
class ErrorResponse(BaseResponse): class ErrorResponse(BaseResponse):
"""Error response with error type and message.""" """Error response with error type and message."""
name: Literal["error"] = "error" name: Literal["error"] = "error"
error_type: str error_type: str
message: str message: str
@ -240,60 +278,70 @@ class ErrorResponse(BaseResponse):
class DataTokenResponse(BaseResponse): class DataTokenResponse(BaseResponse):
"""Contains authentication token.""" """Contains authentication token."""
name: Literal["data_token"] = "data_token" name: Literal["data_token"] = "data_token"
data: str # The authentication token data: str # The authentication token
class DataGameResponse(BaseResponse): class DataGameResponse(BaseResponse):
"""Contains full game object.""" """Contains full game object."""
name: Literal["data_game"] = "data_game" name: Literal["data_game"] = "data_game"
data: Dict[str, Any] # The complete game state data: Dict[str, Any] # The complete game state
class DataGameInfoResponse(BaseResponse): class DataGameInfoResponse(BaseResponse):
"""Contains game metadata.""" """Contains game metadata."""
name: Literal["data_game_info"] = "data_game_info" name: Literal["data_game_info"] = "data_game_info"
data: Dict[str, Any] data: Dict[str, Any]
class DataGamesResponse(BaseResponse): class DataGamesResponse(BaseResponse):
"""List of game information.""" """List of game information."""
name: Literal["data_games"] = "data_games" name: Literal["data_games"] = "data_games"
data: List[Dict[str, Any]] data: List[Dict[str, Any]]
class DataMapsResponse(BaseResponse): class DataMapsResponse(BaseResponse):
"""Available maps information.""" """Available maps information."""
name: Literal["data_maps"] = "data_maps" name: Literal["data_maps"] = "data_maps"
data: List[str] data: List[str]
class DataPowerNamesResponse(BaseResponse): class DataPowerNamesResponse(BaseResponse):
"""List of power names.""" """List of power names."""
name: Literal["data_power_names"] = "data_power_names" name: Literal["data_power_names"] = "data_power_names"
data: List[str] data: List[str]
class DataPossibleOrdersResponse(BaseResponse): class DataPossibleOrdersResponse(BaseResponse):
"""Legal orders and locations.""" """Legal orders and locations."""
name: Literal["data_possible_orders"] = "data_possible_orders" name: Literal["data_possible_orders"] = "data_possible_orders"
data: Dict[str, List[str]] # Location -> list of possible orders data: Dict[str, List[str]] # Location -> list of possible orders
class DataGamePhasesResponse(BaseResponse): class DataGamePhasesResponse(BaseResponse):
"""Historical game phases.""" """Historical game phases."""
name: Literal["data_game_phases"] = "data_game_phases" name: Literal["data_game_phases"] = "data_game_phases"
data: List[Dict[str, Any]] data: List[Dict[str, Any]]
class DataSavedGameResponse(BaseResponse): class DataSavedGameResponse(BaseResponse):
"""Exported game JSON.""" """Exported game JSON."""
name: Literal["data_saved_game"] = "data_saved_game" name: Literal["data_saved_game"] = "data_saved_game"
data: Dict[str, Any] data: Dict[str, Any]
class DataPortResponse(BaseResponse): class DataPortResponse(BaseResponse):
"""DAIDE port number.""" """DAIDE port number."""
name: Literal["data_port"] = "data_port" name: Literal["data_port"] = "data_port"
data: int data: int
@ -302,8 +350,10 @@ class DataPortResponse(BaseResponse):
# Notification Messages (Server -> Client) # Notification Messages (Server -> Client)
# ============================================================================= # =============================================================================
class GameProcessedNotification(BaseNotification): class GameProcessedNotification(BaseNotification):
"""Phase completed, new orders phase.""" """Phase completed, new orders phase."""
name: Literal["game_processed"] = "game_processed" name: Literal["game_processed"] = "game_processed"
game_id: str game_id: str
phase: str phase: str
@ -312,6 +362,7 @@ class GameProcessedNotification(BaseNotification):
class GamePhaseUpdateNotification(BaseNotification): class GamePhaseUpdateNotification(BaseNotification):
"""Game state changed.""" """Game state changed."""
name: Literal["game_phase_update"] = "game_phase_update" name: Literal["game_phase_update"] = "game_phase_update"
game_id: str game_id: str
phase: str phase: str
@ -320,6 +371,7 @@ class GamePhaseUpdateNotification(BaseNotification):
class GameStatusUpdateNotification(BaseNotification): class GameStatusUpdateNotification(BaseNotification):
"""Game status changed (forming/active/paused/completed).""" """Game status changed (forming/active/paused/completed)."""
name: Literal["game_status_update"] = "game_status_update" name: Literal["game_status_update"] = "game_status_update"
game_id: str game_id: str
status: str status: str
@ -327,6 +379,7 @@ class GameStatusUpdateNotification(BaseNotification):
class PowersControllersNotification(BaseNotification): class PowersControllersNotification(BaseNotification):
"""Power control assignments changed.""" """Power control assignments changed."""
name: Literal["powers_controllers"] = "powers_controllers" name: Literal["powers_controllers"] = "powers_controllers"
game_id: str game_id: str
controllers: Dict[str, str] # Power -> Controller mapping controllers: Dict[str, str] # Power -> Controller mapping
@ -334,6 +387,7 @@ class PowersControllersNotification(BaseNotification):
class PowerOrdersUpdateNotification(BaseNotification): class PowerOrdersUpdateNotification(BaseNotification):
"""Player submitted new orders.""" """Player submitted new orders."""
name: Literal["power_orders_update"] = "power_orders_update" name: Literal["power_orders_update"] = "power_orders_update"
game_id: str game_id: str
power_name: str power_name: str
@ -343,6 +397,7 @@ class PowerOrdersUpdateNotification(BaseNotification):
class PowerOrdersFlagNotification(BaseNotification): class PowerOrdersFlagNotification(BaseNotification):
"""Player order submission status.""" """Player order submission status."""
name: Literal["power_orders_flag"] = "power_orders_flag" name: Literal["power_orders_flag"] = "power_orders_flag"
game_id: str game_id: str
power_name: str power_name: str
@ -352,6 +407,7 @@ class PowerOrdersFlagNotification(BaseNotification):
class PowerWaitFlagNotification(BaseNotification): class PowerWaitFlagNotification(BaseNotification):
"""Player wait flag changed.""" """Player wait flag changed."""
name: Literal["power_wait_flag"] = "power_wait_flag" name: Literal["power_wait_flag"] = "power_wait_flag"
game_id: str game_id: str
power_name: str power_name: str
@ -360,6 +416,7 @@ class PowerWaitFlagNotification(BaseNotification):
class GameMessageReceivedNotification(BaseNotification): class GameMessageReceivedNotification(BaseNotification):
"""Diplomatic message received.""" """Diplomatic message received."""
name: Literal["game_message_received"] = "game_message_received" name: Literal["game_message_received"] = "game_message_received"
game_id: str game_id: str
sender: str sender: str
@ -371,6 +428,7 @@ class GameMessageReceivedNotification(BaseNotification):
class VoteUpdatedNotification(BaseNotification): class VoteUpdatedNotification(BaseNotification):
"""Draw votes changed (omniscient view).""" """Draw votes changed (omniscient view)."""
name: Literal["vote_updated"] = "vote_updated" name: Literal["vote_updated"] = "vote_updated"
game_id: str game_id: str
votes: Dict[str, str] # Power -> vote mapping votes: Dict[str, str] # Power -> vote mapping
@ -378,6 +436,7 @@ class VoteUpdatedNotification(BaseNotification):
class VoteCountUpdatedNotification(BaseNotification): class VoteCountUpdatedNotification(BaseNotification):
"""Vote count changed (observer view).""" """Vote count changed (observer view)."""
name: Literal["vote_count_updated"] = "vote_count_updated" name: Literal["vote_count_updated"] = "vote_count_updated"
game_id: str game_id: str
count_yes: int count_yes: int
@ -386,6 +445,7 @@ class VoteCountUpdatedNotification(BaseNotification):
class PowerVoteUpdatedNotification(BaseNotification): class PowerVoteUpdatedNotification(BaseNotification):
"""Own power's vote changed.""" """Own power's vote changed."""
name: Literal["power_vote_updated"] = "power_vote_updated" name: Literal["power_vote_updated"] = "power_vote_updated"
game_id: str game_id: str
power_name: str power_name: str
@ -394,12 +454,14 @@ class PowerVoteUpdatedNotification(BaseNotification):
class GameDeletedNotification(BaseNotification): class GameDeletedNotification(BaseNotification):
"""Game removed from server.""" """Game removed from server."""
name: Literal["game_deleted"] = "game_deleted" name: Literal["game_deleted"] = "game_deleted"
game_id: str game_id: str
class OmniscientUpdatedNotification(BaseNotification): class OmniscientUpdatedNotification(BaseNotification):
"""Observer permissions changed.""" """Observer permissions changed."""
name: Literal["omniscient_updated"] = "omniscient_updated" name: Literal["omniscient_updated"] = "omniscient_updated"
game_id: str game_id: str
omniscient_type: str omniscient_type: str
@ -407,12 +469,14 @@ class OmniscientUpdatedNotification(BaseNotification):
class AccountDeletedNotification(BaseNotification): class AccountDeletedNotification(BaseNotification):
"""User account deleted.""" """User account deleted."""
name: Literal["account_deleted"] = "account_deleted" name: Literal["account_deleted"] = "account_deleted"
username: str username: str
class ClearedCentersNotification(BaseNotification): class ClearedCentersNotification(BaseNotification):
"""Supply centers cleared.""" """Supply centers cleared."""
name: Literal["cleared_centers"] = "cleared_centers" name: Literal["cleared_centers"] = "cleared_centers"
game_id: str game_id: str
power_name: str power_name: str
@ -420,6 +484,7 @@ class ClearedCentersNotification(BaseNotification):
class ClearedOrdersNotification(BaseNotification): class ClearedOrdersNotification(BaseNotification):
"""Orders cleared.""" """Orders cleared."""
name: Literal["cleared_orders"] = "cleared_orders" name: Literal["cleared_orders"] = "cleared_orders"
game_id: str game_id: str
power_name: str power_name: str
@ -428,6 +493,7 @@ class ClearedOrdersNotification(BaseNotification):
class ClearedUnitsNotification(BaseNotification): class ClearedUnitsNotification(BaseNotification):
"""Units cleared.""" """Units cleared."""
name: Literal["cleared_units"] = "cleared_units" name: Literal["cleared_units"] = "cleared_units"
game_id: str game_id: str
power_name: str power_name: str
@ -514,6 +580,7 @@ WebSocketMessage = Union[RequestMessage, ResponseMessage, NotificationMessage]
# Utility Functions # Utility Functions
# ============================================================================= # =============================================================================
def parse_message(data: Dict[str, Any]) -> WebSocketMessage: def parse_message(data: Dict[str, Any]) -> WebSocketMessage:
""" """
Parse a raw WebSocket message dictionary into the appropriate pydantic model. Parse a raw WebSocket message dictionary into the appropriate pydantic model.

View file

@ -50,13 +50,13 @@ class MultiBotLauncher:
# Default power to model mapping # Default power to model mapping
self.default_models = { self.default_models = {
"AUSTRIA": "gpt-3.5-turbo", "AUSTRIA": "gemini-2.5-flash-lite-preview-06-17",
"ENGLAND": "gpt-4", "ENGLAND": "gemini-2.5-flash-lite-preview-06-17",
"FRANCE": "claude-3-haiku", "FRANCE": "gemini-2.5-flash-lite-preview-06-17",
"GERMANY": "gpt-3.5-turbo", "GERMANY": "gemini-2.5-flash-lite-preview-06-17",
"ITALY": "gemini-pro", "ITALY": "gemini-2.5-flash-lite-preview-06-17",
"RUSSIA": "gpt-4", "RUSSIA": "gemini-2.5-flash-lite-preview-06-17",
"TURKEY": "claude-3-sonnet", "TURKEY": "gemini-2.5-flash-lite-preview-06-17",
} }
async def create_game(self, creator_power: str = "FRANCE") -> str: async def create_game(self, creator_power: str = "FRANCE") -> str:
@ -99,7 +99,12 @@ class MultiBotLauncher:
return game_id return game_id
def launch_bot( def launch_bot(
self, power: str, model: str, game_id: str, log_level: str = "INFO" self,
power: str,
model: str,
game_id: str,
log_level: str = "INFO",
negotiation_rounds: int = 3,
) -> subprocess.Popen: ) -> subprocess.Popen:
""" """
Launch a single bot process. Launch a single bot process.
@ -134,6 +139,8 @@ class MultiBotLauncher:
game_id, game_id,
"--log-level", "--log-level",
log_level, log_level,
"--negotiation-rounds",
str(negotiation_rounds),
] ]
logger.info(f"Launching bot for {power} with model {model}") logger.info(f"Launching bot for {power} with model {model}")
@ -156,7 +163,8 @@ class MultiBotLauncher:
models: Optional[Dict[str, str]] = None, models: Optional[Dict[str, str]] = None,
powers: Optional[List[str]] = None, powers: Optional[List[str]] = None,
log_level: str = "INFO", log_level: str = "INFO",
stagger_delay: float = 2.0, stagger_delay: float = 0.5,
negotiation_rounds: int = 3,
): ):
""" """
Launch bots for all specified powers. Launch bots for all specified powers.
@ -180,13 +188,11 @@ class MultiBotLauncher:
model = models.get(power, "gpt-3.5-turbo") model = models.get(power, "gpt-3.5-turbo")
try: try:
process = self.launch_bot(power, model, game_id, log_level) process = self.launch_bot(power, model, game_id, log_level, negotiation_rounds)
self.bot_processes.append(process) self.bot_processes.append(process)
self.process_to_power[process] = power self.process_to_power[process] = power
logger.info( logger.info(f"Launched bot {i + 1}/{len(powers)}: {power} (PID: {process.pid})")
f"Launched bot {i + 1}/{len(powers)}: {power} (PID: {process.pid})"
)
# Stagger the launches to avoid overwhelming the server # Stagger the launches to avoid overwhelming the server
if i < len(powers) - 1: # Don't delay after the last bot if i < len(powers) - 1: # Don't delay after the last bot
@ -224,9 +230,7 @@ class MultiBotLauncher:
# Process has ended # Process has ended
return_code = process.returncode return_code = process.returncode
power = self.process_to_power.get(process, "UNKNOWN") power = self.process_to_power.get(process, "UNKNOWN")
logger.info( logger.info(f"{power} bot process {process.pid} ended with code {return_code}")
f"{power} bot process {process.pid} ended with code {return_code}"
)
# Read any remaining output # Read any remaining output
remaining_output = process.stdout.read() remaining_output = process.stdout.read()
@ -243,7 +247,7 @@ class MultiBotLauncher:
break break
# Use select to check which processes have output ready (Unix only) # Use select to check which processes have output ready (Unix only)
if stdout_fds and hasattr(select, 'select'): if stdout_fds and hasattr(select, "select"):
try: try:
ready_fds, _, _ = select.select(stdout_fds, [], [], 0.1) # 100ms timeout ready_fds, _, _ = select.select(stdout_fds, [], [], 0.1) # 100ms timeout
@ -324,6 +328,7 @@ class MultiBotLauncher:
models: Optional[Dict[str, str]] = None, models: Optional[Dict[str, str]] = None,
log_level: str = "INFO", log_level: str = "INFO",
creator_power: str = "FRANCE", creator_power: str = "FRANCE",
negotiation_rounds: int = 3,
): ):
""" """
Create a game and launch all bots for a complete game. Create a game and launch all bots for a complete game.
@ -342,7 +347,12 @@ class MultiBotLauncher:
await asyncio.sleep(2) await asyncio.sleep(2)
# Launch all bots # Launch all bots
await self.launch_all_bots(game_id, models, log_level=log_level) await self.launch_all_bots(
game_id,
models,
log_level=log_level,
negotiation_rounds=negotiation_rounds,
)
# Monitor the bots # Monitor the bots
self.monitor_bots() self.monitor_bots()
@ -358,6 +368,7 @@ class MultiBotLauncher:
powers: List[str], powers: List[str],
models: Optional[Dict[str, str]] = None, models: Optional[Dict[str, str]] = None,
log_level: str = "INFO", log_level: str = "INFO",
negotiation_rounds: int = 3,
): ):
""" """
Launch bots to join an existing game. Launch bots to join an existing game.
@ -372,7 +383,13 @@ class MultiBotLauncher:
self.game_id = game_id self.game_id = game_id
# Launch bots for specified powers # Launch bots for specified powers
await self.launch_all_bots(game_id, models, powers, log_level) await self.launch_all_bots(
game_id,
models,
powers,
log_level,
negotiation_rounds=negotiation_rounds,
)
# Monitor the bots # Monitor the bots
self.monitor_bots() self.monitor_bots()
@ -391,18 +408,16 @@ def parse_arguments():
parser.add_argument("--port", type=int, default=8432, help="Server port") parser.add_argument("--port", type=int, default=8432, help="Server port")
parser.add_argument("--username-base", default="bot", help="Base username for bots") parser.add_argument("--username-base", default="bot", help="Base username for bots")
parser.add_argument("--password", default="password", help="Password for all bots") parser.add_argument("--password", default="password", help="Password for all bots")
parser.add_argument( parser.add_argument("--game-id", help="Game ID to join (creates new if not specified)")
"--game-id", help="Game ID to join (creates new if not specified)" parser.add_argument("--powers", nargs="+", help="Powers to launch bots for (default: all)")
) parser.add_argument("--models", help="Comma-separated list of models in power order")
parser.add_argument(
"--powers", nargs="+", help="Powers to launch bots for (default: all)"
)
parser.add_argument(
"--models", help="Comma-separated list of models in power order"
)
parser.add_argument("--log-level", default="INFO", help="Logging level") parser.add_argument("--log-level", default="INFO", help="Logging level")
parser.add_argument("--creator-power", default="FRANCE", help="Power that creates the game")
parser.add_argument( parser.add_argument(
"--creator-power", default="FRANCE", help="Power that creates the game" "--negotiation-rounds",
type=int,
default=3,
help="Number of negotiation rounds per movement phase (default: 3)",
) )
return parser.parse_args() return parser.parse_args()
@ -425,9 +440,7 @@ async def main():
model_list = [m.strip() for m in args.models.split(",")] model_list = [m.strip() for m in args.models.split(",")]
powers = args.powers or list(launcher.default_models.keys()) powers = args.powers or list(launcher.default_models.keys())
if len(model_list) != len(powers): if len(model_list) != len(powers):
logger.error( logger.error(f"Number of models ({len(model_list)}) must match number of powers ({len(powers)})")
f"Number of models ({len(model_list)}) must match number of powers ({len(powers)})"
)
return return
models = dict(zip(powers, model_list)) models = dict(zip(powers, model_list))
@ -440,6 +453,7 @@ async def main():
powers=powers, powers=powers,
models=models, models=models,
log_level=args.log_level, log_level=args.log_level,
negotiation_rounds=args.negotiation_rounds,
) )
else: else:
# Create new game and launch all bots # Create new game and launch all bots
@ -447,6 +461,7 @@ async def main():
models=models, models=models,
log_level=args.log_level, log_level=args.log_level,
creator_power=args.creator_power, creator_power=args.creator_power,
negotiation_rounds=args.negotiation_rounds,
) )
except KeyboardInterrupt: except KeyboardInterrupt:

View file

@ -29,6 +29,7 @@ include = [".", ".."]
[tool.ruff] [tool.ruff]
src = ["..", "."] src = ["..", "."]
line-length = 150
[tool.pytest.ini_options] [tool.pytest.ini_options]
testpaths = ["bot_client/tests"] testpaths = ["bot_client/tests"]
@ -44,8 +45,4 @@ markers = [
] ]
[dependency-groups] [dependency-groups]
dev = [ dev = ["pytest>=8.4.1", "pytest-asyncio>=0.21.0", "websockets>=15.0.1"]
"pytest>=8.4.1",
"pytest-asyncio>=0.21.0",
"websockets>=15.0.1",
]

View file

@ -13,25 +13,17 @@ sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
import argparse import argparse
import asyncio import asyncio
import signal import signal
from typing import Optional, Dict from typing import Optional, Dict, List
import dotenv import dotenv
from loguru import logger from loguru import logger
from websocket_diplomacy_client import ( from websocket_diplomacy_client import WebSocketDiplomacyClient, connect_to_diplomacy_server
WebSocketDiplomacyClient,
connect_to_diplomacy_server,
)
from diplomacy.utils.exceptions import DiplomacyException, GameIdException
# Suppress warnings
# os.environ["GRPC_PYTHON_LOG_LEVEL"] = "40"
# os.environ["GRPC_VERBOSITY"] = "ERROR"
# os.environ["ABSL_MIN_LOG_LEVEL"] = "2"
# os.environ["GRPC_POLL_STRATEGY"] = "poll"
from diplomacy.engine.power import Power
from diplomacy.utils.exceptions import GameIdException
from diplomacy.communication.notifications import GameStatusUpdate
from diplomacy.engine.message import Message from diplomacy.engine.message import Message
from ai_diplomacy.clients import load_model_client from ai_diplomacy.clients import load_model_client
@ -40,15 +32,16 @@ from ai_diplomacy.game_history import GameHistory
from ai_diplomacy.agent import DiplomacyAgent from ai_diplomacy.agent import DiplomacyAgent
from ai_diplomacy.initialization import initialize_agent_state_ext from ai_diplomacy.initialization import initialize_agent_state_ext
from config import Configuration from config import Configuration
from websocket_negotiations import (
conduct_strategic_negotiation_round,
should_participate_in_negotiations,
get_negotiation_delay,
)
dotenv.load_dotenv() dotenv.load_dotenv()
config = Configuration() # TODO: This, but better
config: Configuration = None
if config.DEBUG:
import tracemalloc
tracemalloc.start()
class SingleBotPlayer: class SingleBotPlayer:
@ -70,6 +63,7 @@ class SingleBotPlayer:
hostname: str = "localhost", hostname: str = "localhost",
port: int = 8432, port: int = 8432,
game_id: Optional[str] = None, game_id: Optional[str] = None,
negotiation_rounds: int = 3,
): ):
assert username is not None assert username is not None
assert password is not None assert password is not None
@ -87,16 +81,24 @@ class SingleBotPlayer:
# Bot state # Bot state
self.client: WebSocketDiplomacyClient self.client: WebSocketDiplomacyClient
self.agent: DiplomacyAgent self.agent: DiplomacyAgent
self.game_history = GameHistory() self.game_history: GameHistory = GameHistory()
self.running = True self.running = True
self.current_phase = None self.current_phase = None
self.waiting_for_orders = False self.waiting_for_orders = False
self.orders_submitted = False self.orders_submitted = False
# Negotiation settings
self.negotiation_rounds = negotiation_rounds
self.current_negotiation_round = 0
self.negotiation_complete = False
# Track error stats # Track error stats
self.error_stats: Dict[str, Dict[str, int]] = { self.error_stats: Dict[str, Dict[str, int]] = {self.model_name: {"conversation_errors": 0, "order_decoding_errors": 0}}
self.model_name: {"conversation_errors": 0, "order_decoding_errors": 0}
} # Track messaging patterns for strategic communication
self.message_counts: Dict[str, int] = {} # Messages received from each power
self.response_counts: Dict[str, int] = {} # Responses sent to each power
self.priority_contacts: List[str] = [] # Powers to prioritize for communication
# Setup signal handlers for graceful shutdown # Setup signal handlers for graceful shutdown
signal.signal(signal.SIGINT, self._signal_handler) signal.signal(signal.SIGINT, self._signal_handler)
@ -122,9 +124,7 @@ class SingleBotPlayer:
# Join or create game # Join or create game
if self.game_id: if self.game_id:
logger.info(f"Joining existing game {self.game_id} as {self.power_name}") logger.info(f"Joining existing game {self.game_id} as {self.power_name}")
game = await self.client.join_game( await self.client.join_game(game_id=self.game_id, power_name=self.power_name)
game_id=self.game_id, power_name=self.power_name
)
else: else:
logger.info(f"Creating new game as {self.power_name}") logger.info(f"Creating new game as {self.power_name}")
await self.client.create_game( await self.client.create_game(
@ -142,9 +142,7 @@ class SingleBotPlayer:
self.agent = DiplomacyAgent(power_name=self.power_name, client=model_client) self.agent = DiplomacyAgent(power_name=self.power_name, client=model_client)
# Initialize agent state # Initialize agent state
await initialize_agent_state_ext( await initialize_agent_state_ext(self.agent, self.client.game, self.game_history, config.log_file_path)
self.agent, self.client.game, self.game_history, config.log_file_path
)
# Setup game event callbacks # Setup game event callbacks
await self._setup_event_callbacks() await self._setup_event_callbacks()
@ -198,6 +196,16 @@ class SingleBotPlayer:
self.current_phase = new_phase self.current_phase = new_phase
self.game_history.add_phase(new_phase) self.game_history.add_phase(new_phase)
self.orders_submitted = False self.orders_submitted = False
self.current_negotiation_round = 0
self.negotiation_complete = False
# Log message statistics at phase transitions
if hasattr(self, "message_counts") and self.message_counts:
self.log_message_statistics()
# Check if we should start negotiations for movement phases (not adjustment phases)
if new_phase.endswith("M"):
await self._handle_negotiation_phase()
# Check if we need to submit orders for this new phase # Check if we need to submit orders for this new phase
await self._check_if_orders_needed() await self._check_if_orders_needed()
@ -223,9 +231,7 @@ class SingleBotPlayer:
def _on_message_received(self, game, notification): def _on_message_received(self, game, notification):
"""Handle incoming diplomatic messages.""" """Handle incoming diplomatic messages."""
message = notification.message message = notification.message
logger.info( logger.info(f"Message received from {message.sender} to {message.recipient}: {message.message}")
f"Message received from {message.sender} to {message.recipient}: {message.message}"
)
# Add message to game history # Add message to game history
self.game_history.add_message( self.game_history.add_message(
@ -235,12 +241,17 @@ class SingleBotPlayer:
message_content=message.message, message_content=message.message,
) )
# Track message patterns
if message.recipient == self.power_name and message.sender != self.power_name:
self.message_counts[message.sender] = self.message_counts.get(message.sender, 0) + 1
self._update_priority_contacts()
# If it's a private message to us, consider responding # If it's a private message to us, consider responding
if message.recipient == self.power_name and message.sender != self.power_name: if message.recipient == self.power_name and message.sender != self.power_name:
# Schedule the async processing in the event loop # Schedule the async processing in the event loop
asyncio.create_task(self._consider_message_response(message)) asyncio.create_task(self._consider_message_response(message))
def _on_status_update(self, game, notification): def _on_status_update(self, game, notification: GameStatusUpdate):
"""Handle game status changes.""" """Handle game status changes."""
logger.info(f"Game status updated: {notification.status}") logger.info(f"Game status updated: {notification.status}")
@ -261,20 +272,50 @@ class SingleBotPlayer:
# Check if it's a phase where we can submit orders # Check if it's a phase where we can submit orders
current_short_phase = self.client.game.current_short_phase current_short_phase = self.client.game.current_short_phase
# We submit orders in Movement and Retreat phases logger.debug(f"Checking if orders needed for phase: {current_short_phase}")
if current_short_phase.endswith("M") or current_short_phase.endswith("R"):
# Check if we have units that can receive orders # Movement and Retreat phases
orderable_locations = self.client.game.get_orderable_locations( orderable_locations = self.client.game.get_orderable_locations(self.power_name)
self.power_name
)
if orderable_locations: if orderable_locations:
logger.info(f"Orders needed for phase {current_short_phase}") logger.info(f"Orders needed for {current_short_phase} phase - orderable locations: {orderable_locations}")
self.waiting_for_orders = True self.waiting_for_orders = True
await self._submit_orders() await self._submit_orders()
else: else:
logger.info( logger.info(f"No orderable locations for {self.power_name} in {current_short_phase}")
f"No orderable locations for {self.power_name} in {current_short_phase}"
) async def _submit_adjustment_orders(self, action_type, count):
"""Submit build or disband orders for adjustment phase."""
# FIXME: This whole function is horse shit.
return
if action_type == "build":
# Get buildable locations for this power
possible_orders = gather_possible_order(self.client.game, self.power_name)
if not possible_orders:
logger.warning(f"No possible build orders for {self.power_name}")
await self.client.set_orders(self.power_name, [])
self.orders_submitted = True
return
# Filter for build orders (usually start with unit type + location)
build_orders = [order for order in possible_orders if " - " not in order and any(order.startswith(unit) for unit in ["A ", "F "])]
logger.info(f"Available build orders for {self.power_name}: {build_orders}")
# Select up to 'count' build orders
selected_orders = build_orders[:count]
logger.info(f"Submitting build orders for {self.power_name}: {selected_orders}")
await self.client.set_orders(self.power_name, selected_orders)
elif action_type == "disband":
# Get current units for disbanding
logger.info(f"Submitting disband orders for {self.power_name}: {disband_orders}")
await self.client.set_orders(self.power_name, disband_orders)
self.orders_submitted = True
self.waiting_for_orders = False
logger.info(f"Adjustment orders submitted successfully for {self.power_name}")
async def _submit_orders(self): async def _submit_orders(self):
"""Generate and submit orders for the current phase.""" """Generate and submit orders for the current phase."""
@ -282,8 +323,8 @@ class SingleBotPlayer:
logger.debug("Orders already submitted for this phase") logger.debug("Orders already submitted for this phase")
return return
try: current_phase = self.client.game.get_current_phase()
logger.info("Generating orders...") logger.info(f"Generating orders for {self.power_name} in phase {current_phase}...")
# Get current board state # Get current board state
board_state = self.client.game.get_state() board_state = self.client.game.get_state()
@ -291,8 +332,10 @@ class SingleBotPlayer:
# Get possible orders # Get possible orders
possible_orders = gather_possible_orders(self.client.game, self.power_name) possible_orders = gather_possible_orders(self.client.game, self.power_name)
logger.debug(f"Possible orders for {self.power_name}: {possible_orders}")
if not possible_orders: if not possible_orders:
logger.info("No possible orders, submitting empty order set") logger.info(f"No possible orders for {self.power_name}, submitting empty order set")
await self.client.set_orders(self.power_name, []) await self.client.set_orders(self.power_name, [])
self.orders_submitted = True self.orders_submitted = True
return return
@ -331,22 +374,12 @@ class SingleBotPlayer:
self.waiting_for_orders = False self.waiting_for_orders = False
logger.info("Orders submitted successfully") logger.info("Orders submitted successfully")
# Call the no wait so we don't sit around for the turns to end. # Call the no wait so we don't sit around for the turns to end.
# TODO: We probably don't want to call this here.
# We want to call it when negotiations end,
self.client.game.no_wait() self.client.game.no_wait()
except DiplomacyException as e:
logger.error(f"Error submitting orders: {e}", exc_info=True)
# FIXME: I don't think we want to do this. Likely want to retry again multiple times.
#
# Submit empty orders as fallback
try:
await self.client.set_orders(self.power_name, [])
self.orders_submitted = True
except Exception as fallback_error:
logger.error(f"Failed to submit fallback orders: {fallback_error}")
async def _analyze_phase_results(self): async def _analyze_phase_results(self):
"""Analyze the results of the previous phase.""" """Analyze the results of the previous phase."""
try:
logger.info("Analyzing phase results...") logger.info("Analyzing phase results...")
# Get current board state after processing # Get current board state after processing
@ -366,24 +399,204 @@ class SingleBotPlayer:
logger.info("Phase analysis complete") logger.info("Phase analysis complete")
except Exception as e: async def _handle_negotiation_phase(self):
logger.error(f"Error analyzing phase results: {e}", exc_info=True) """Handle the negotiation phase for movement turns."""
# Check if we should participate in negotiations
if not await should_participate_in_negotiations(self.client, self.agent):
logger.info(f"{self.power_name} will not participate in negotiations this phase")
self.negotiation_complete = True
return
logger.info(f"Starting negotiation phase for {self.power_name}")
# TODO: This doesn't need a specific number of negotiation rounds, though it should have a top number of messages this turn so they don't blabber on forever.
# Conduct negotiations for the specified number of rounds
for round_num in range(1, self.negotiation_rounds + 1):
self.current_negotiation_round = round_num
logger.info(f"Negotiation round {round_num}/{self.negotiation_rounds} for {self.power_name}")
# Use strategic negotiation that analyzes recent messages
success = await conduct_strategic_negotiation_round(
client=self.client,
agent=self.agent,
game_history=self.game_history,
model_error_stats=self.error_stats,
log_file_path=config.log_file_path,
round_number=round_num,
max_rounds=self.negotiation_rounds,
)
if not success:
logger.info(f"No messages sent in round {round_num} for {self.power_name}")
# Wait between rounds to allow other bots to respond
if round_num < self.negotiation_rounds:
delay = get_negotiation_delay(round_num, self.negotiation_rounds)
logger.debug(f"Waiting {delay}s before next negotiation round")
await asyncio.sleep(delay)
self.negotiation_complete = True
logger.info(f"Negotiation phase complete for {self.power_name}")
async def _consider_message_response(self, message: Message): async def _consider_message_response(self, message: Message):
"""Consider whether to respond to a diplomatic message.""" """Consider whether to respond to a diplomatic message."""
try: try:
# Simple logic: if someone greets us, greet back # Only respond to messages directed at us specifically
if any( if message.recipient != self.power_name:
word in message.message.lower() for word in ["hello", "hi", "greetings"] return
):
response = f"Hello {message.sender}! Good to hear from you." # Don't respond to our own messages
await self.client.game.send_game_message( if message.sender == self.power_name:
sender=self.power_name, recipient=message.sender, message=response return
logger.info(f"Considering response to message from {message.sender}: {message.message[:50]}...")
# Enhanced heuristic: respond to direct questions, proposals, and strategic keywords
message_lower = message.message.lower()
strategic_keywords = [
"alliance",
"deal",
"propose",
"agreement",
"support",
"attack",
"coordinate",
"move",
"order",
"help",
"work together",
"partner",
"enemy",
"threat",
"negotiate",
"discuss",
"plan",
"strategy",
"bounce",
"convoy",
"retreat",
]
should_respond = any(
[
"?" in message.message, # Questions
any(word in message_lower for word in ["hello", "hi", "greetings"]), # Greetings
any(keyword in message_lower for keyword in strategic_keywords), # Strategic content
len(message.message.split()) > 15, # Longer messages suggest they want engagement
message.sender in self.priority_contacts, # Priority contacts
]
) )
logger.info(f"Sent response to {message.sender}: {response}")
if should_respond:
# Generate a contextual response using AI
# Get current game state for context
board_state = self.client.get_state()
possible_orders = gather_possible_orders(self.client.game, self.power_name)
# Create a simple conversation context
active_powers = [p_name for p_name, p_obj in self.client.powers.items() if not p_obj.is_eliminated()]
# Generate response using the agent's conversation capabilities
responses = await self.agent.client.get_conversation_reply(
game=self.client.game,
board_state=board_state,
power_name=self.power_name,
possible_orders=possible_orders,
game_history=self.game_history,
game_phase=self.client.get_current_short_phase(),
log_file_path=config.log_file_path,
active_powers=active_powers,
agent_goals=self.agent.goals,
agent_relationships=self.agent.relationships,
agent_private_diary_str=self.agent.format_private_diary_for_prompt(),
)
# Send the first response if any were generated
if responses and len(responses) > 0:
response_content = responses[0].get("content", "").strip()
if response_content:
await self.client.send_message(
sender=self.power_name,
recipient=message.sender,
message=response_content,
phase=self.client.get_current_short_phase(),
)
# Add to game history
self.game_history.add_message(
phase_name=self.client.get_current_short_phase(),
sender=self.power_name,
recipient=message.sender,
message_content=response_content,
)
# Track response patterns
self.response_counts[message.sender] = self.response_counts.get(message.sender, 0) + 1
# Add to agent's journal
self.agent.add_journal_entry(
f"Responded to {message.sender} in {self.client.get_current_short_phase()}: {response_content[:100]}..."
)
logger.info(f"Sent AI response to {message.sender}: {response_content[:50]}...")
else:
logger.debug(f"AI generated empty response to {message.sender}")
else:
logger.debug(f"AI generated no responses to {message.sender}")
else:
logger.debug(f"Decided not to respond to message from {message.sender}")
except Exception as e: except Exception as e:
logger.error(f"Error responding to message: {e}") logger.error(f"Error responding to message: {e}", exc_info=True)
def _update_priority_contacts(self) -> None:
"""Update the list of priority contacts based on messaging patterns."""
# Sort powers by message count (descending) and take top 3-4
sorted_contacts = sorted(self.message_counts.items(), key=lambda x: x[1], reverse=True)
# Keep top 4 most active contacts as priority
self.priority_contacts = [contact[0] for contact in sorted_contacts[:4]]
logger.debug(f"Updated priority contacts for {self.power_name}: {self.priority_contacts}")
def get_message_statistics(self) -> Dict[str, any]:
"""Get comprehensive statistics about messaging patterns."""
active_powers = [p_name for p_name, p_obj in self.client.powers.items() if not p_obj.is_eliminated() and p_name != self.power_name]
stats = {
"power_name": self.power_name,
"total_messages_received": sum(self.message_counts.values()),
"total_responses_sent": sum(self.response_counts.values()),
"message_counts_by_power": dict(self.message_counts),
"response_counts_by_power": dict(self.response_counts),
"priority_contacts": list(self.priority_contacts),
"response_rate_by_power": {},
"active_powers": active_powers,
"current_phase": self.current_phase,
}
# Calculate response rates
for power, received in self.message_counts.items():
sent = self.response_counts.get(power, 0)
stats["response_rate_by_power"][power] = sent / received if received > 0 else 0.0
return stats
def log_message_statistics(self) -> None:
"""Log current message statistics for analysis."""
stats = self.get_message_statistics()
logger.info(f"Message Statistics for {self.power_name}:")
logger.info(f" Total messages received: {stats['total_messages_received']}")
logger.info(f" Total responses sent: {stats['total_responses_sent']}")
logger.info(f" Priority contacts: {stats['priority_contacts']}")
for power in stats["active_powers"]:
received = stats["message_counts_by_power"].get(power, 0)
sent = stats["response_counts_by_power"].get(power, 0)
rate = stats["response_rate_by_power"].get(power, 0.0)
logger.info(f" {power}: {received} received, {sent} sent, {rate:.1%} response rate")
async def run(self): async def run(self):
"""Main bot loop.""" """Main bot loop."""
@ -408,9 +621,7 @@ class SingleBotPlayer:
else: else:
logger.info("Bot shutting down") logger.info("Bot shutting down")
except GameIdException: except GameIdException:
logger.error( logger.error(f"Game with id {self.game_id} does not exist on the server. Exiting...")
f"Game with id {self.game_id} does not exist on the server. Exiting..."
)
finally: finally:
await self.cleanup() await self.cleanup()
@ -437,10 +648,14 @@ def parse_arguments():
parser.add_argument("--password", default="password", help="Bot password") parser.add_argument("--password", default="password", help="Bot password")
parser.add_argument("--power", default="FRANCE", help="Power to control") parser.add_argument("--power", default="FRANCE", help="Power to control")
parser.add_argument("--model", default="gpt-3.5-turbo", help="AI model to use") parser.add_argument("--model", default="gpt-3.5-turbo", help="AI model to use")
parser.add_argument( parser.add_argument("--game-id", help="Game ID to join (creates new if not specified)")
"--game-id", help="Game ID to join (creates new if not specified)"
)
parser.add_argument("--log-level", default="INFO", help="Logging level") parser.add_argument("--log-level", default="INFO", help="Logging level")
parser.add_argument(
"--negotiation-rounds",
type=int,
default=3,
help="Number of negotiation rounds per movement phase (default: 3)",
)
return parser.parse_args() return parser.parse_args()
@ -459,8 +674,9 @@ async def main():
power_name=args.power, power_name=args.power,
model_name=args.model, model_name=args.model,
game_id=args.game_id, game_id=args.game_id,
negotiation_rounds=args.negotiation_rounds,
) )
config = Configuration(args.power)
await bot.run() await bot.run()

View file

@ -27,9 +27,7 @@ class WebSocketDiplomacyClient:
but communicates with a remote server via WebSocket connections. but communicates with a remote server via WebSocket connections.
""" """
def __init__( def __init__(self, hostname: str = "localhost", port: int = 8432, use_ssl: bool = False):
self, hostname: str = "localhost", port: int = 8432, use_ssl: bool = False
):
""" """
Initialize the WebSocket client. Initialize the WebSocket client.
@ -199,9 +197,7 @@ class WebSocketDiplomacyClient:
return await self.channel.get_available_maps() return await self.channel.get_available_maps()
async def set_orders( async def set_orders(self, power_name: str, orders: List[str], wait: Optional[bool] = None) -> None:
self, power_name: str, orders: List[str], wait: Optional[bool] = None
) -> None:
""" """
Set orders for a power. Set orders for a power.
@ -246,9 +242,7 @@ class WebSocketDiplomacyClient:
else: else:
await self.game.no_wait() await self.game.no_wait()
async def send_message( async def send_message(self, sender: str, recipient: str, message: str, phase: Optional[str] = None) -> None:
self, sender: str, recipient: str, message: str, phase: Optional[str] = None
) -> None:
""" """
Send a diplomatic message. Send a diplomatic message.
@ -269,6 +263,57 @@ class WebSocketDiplomacyClient:
logger.debug(f"Sending message from {sender} to {recipient}: {message}") logger.debug(f"Sending message from {sender} to {recipient}: {message}")
await self.game.send_game_message(message=msg) await self.game.send_game_message(message=msg)
async def send_broadcast_message(self, sender: str, message: str, phase: Optional[str] = None) -> None:
"""
Send a broadcast message to all active powers.
Args:
sender: Sending power name
message: Message content
phase: Game phase (uses current phase if None)
"""
if not self.game:
raise DiplomacyException("Must join a game first")
if phase is None:
phase = self.game.current_short_phase
# Send to all active powers
active_powers = [power_name for power_name, power in self.powers.items() if not power.is_eliminated() and power_name != sender]
for recipient in active_powers:
await self.send_message(sender, recipient, message, phase)
logger.debug(f"Broadcast message from {sender} to {len(active_powers)} powers: {message}")
async def get_recent_messages(self, phase: Optional[str] = None, limit: int = 50) -> List[Message]:
"""
Get recent messages from the game.
Args:
phase: Filter by specific phase (None for current phase)
limit: Maximum number of messages to return
Returns:
List of recent messages
"""
if not self.game:
raise DiplomacyException("Must join a game first")
if phase is None:
phase = self.game.current_short_phase
# Get messages from the game and filter by phase
all_messages = list(self.messages.values())
# Filter by phase if specified
filtered_messages = [msg for msg in all_messages if msg.phase == phase]
# Sort by time (most recent first) and limit
filtered_messages.sort(key=lambda m: m.time_sent or 0, reverse=True)
return filtered_messages[:limit]
async def process_game(self) -> None: async def process_game(self) -> None:
""" """
Force the game to process immediately (admin/moderator only). Force the game to process immediately (admin/moderator only).
@ -289,9 +334,7 @@ class WebSocketDiplomacyClient:
logger.debug("Synchronizing game state") logger.debug("Synchronizing game state")
await self.game.synchronize() await self.game.synchronize()
async def get_phase_history( async def get_phase_history(self, from_phase: Optional[str] = None, to_phase: Optional[str] = None) -> List[Dict[str, Any]]:
self, from_phase: Optional[str] = None, to_phase: Optional[str] = None
) -> List[Dict[str, Any]]:
""" """
Get phase history for the game. Get phase history for the game.
@ -305,9 +348,7 @@ class WebSocketDiplomacyClient:
if not self.game: if not self.game:
raise DiplomacyException("Must join a game first") raise DiplomacyException("Must join a game first")
return await self.game.get_phase_history( return await self.game.get_phase_history(from_phase=from_phase, to_phase=to_phase)
from_phase=from_phase, to_phase=to_phase
)
async def vote(self, power_name: str, vote: str) -> None: async def vote(self, power_name: str, vote: str) -> None:
""" """

View file

@ -0,0 +1,326 @@
"""
WebSocket-specific negotiation logic for single bot players.
This module provides negotiation capabilities for individual bots connected
via WebSocket, adapted from the multi-agent negotiation system in the main
ai_diplomacy package.
"""
from typing import Dict, Optional, List
from loguru import logger
from diplomacy.engine.message import GLOBAL
from ai_diplomacy.utils import gather_possible_orders
from ai_diplomacy.agent import DiplomacyAgent
from ai_diplomacy.game_history import GameHistory
from websocket_diplomacy_client import WebSocketDiplomacyClient
async def conduct_single_bot_negotiation(
client: "WebSocketDiplomacyClient",
agent: "DiplomacyAgent",
game_history: "GameHistory",
model_error_stats: Dict[str, Dict[str, int]],
log_file_path: str,
max_rounds: int = 3,
round_number: int = 1,
prioritize_targets: Optional[List[str]] = None,
) -> bool:
"""
Conduct negotiation for a single bot during one negotiation round.
This function handles message generation and sending for one bot during
a negotiation phase. Unlike the multi-agent version, this focuses on
a single power and sends messages via WebSocket.
Args:
client: WebSocket diplomacy client
agent: The bot's AI agent
game_history: Game history tracker
model_error_stats: Error statistics tracking
log_file_path: Path for logging
max_rounds: Maximum number of negotiation rounds
round_number: Current round number (1-indexed)
prioritize_targets: Optional list of powers to prioritize for messaging
Returns:
True if messages were sent successfully, False otherwise
"""
power_name = agent.power_name
logger.info(f"Starting negotiation round {round_number}/{max_rounds} for {power_name}")
# Check if this power is eliminated
if client.get_power(power_name).is_eliminated():
logger.info(f"{power_name} is eliminated, skipping negotiation")
return False
# Check if this power has any orderable locations
possible_orders = gather_possible_orders(client.game, power_name)
if not possible_orders:
logger.info(f"No orderable locations for {power_name}, skipping negotiation")
return False
# Get active powers for context
active_powers = [p_name for p_name, p_obj in client.powers.items() if not p_obj.is_eliminated()]
# Prioritize message targets if specified
message_targets = prioritize_targets if prioritize_targets else active_powers
message_targets = [p for p in message_targets if p in active_powers and p != power_name]
# Generate conversation messages using the AI agent
board_state = client.get_state()
messages = await agent.client.get_conversation_reply(
game=client.game,
board_state=board_state,
power_name=power_name,
possible_orders=possible_orders,
game_history=game_history,
game_phase=client.get_current_short_phase(),
log_file_path=log_file_path,
active_powers=active_powers,
agent_goals=agent.goals,
agent_relationships=agent.relationships,
agent_private_diary_str=agent.format_private_diary_for_prompt(),
)
if not messages:
logger.debug(f"No messages generated for {power_name} in round {round_number}")
return False
# Process and send each message
messages_sent = 0
for message in messages:
success = await _send_negotiation_message(client, agent, game_history, message, power_name)
if success:
messages_sent += 1
logger.info(f"Sent {messages_sent}/{len(messages)} messages for {power_name}")
return messages_sent > 0
async def _send_negotiation_message(
client: WebSocketDiplomacyClient,
agent: DiplomacyAgent,
game_history: GameHistory,
message: Dict,
power_name: str,
) -> bool:
"""
Send a single negotiation message via WebSocket.
Args:
client: WebSocket diplomacy client
agent: The bot's AI agent
game_history: Game history tracker
message: Message dictionary with content and metadata
power_name: Name of the sending power
Returns:
True if message was sent successfully, False otherwise
"""
# Validate message structure
if not isinstance(message, dict) or "content" not in message:
logger.warning(f"Invalid message format from {power_name}: {message}")
return False
content = message.get("content", "").strip()
if not content:
logger.debug(f"Empty message content from {power_name}, skipping")
return False
# Determine recipient
recipient = GLOBAL # Default to global
if message.get("message_type") == "private":
recipient = message.get("recipient", GLOBAL)
# Validate recipient is a valid power
if recipient not in client.powers and recipient != GLOBAL:
logger.warning(f"Invalid recipient '{recipient}' from {power_name}, sending globally")
recipient = GLOBAL
# Send the message via WebSocket
await client.send_message(
sender=power_name,
recipient=recipient,
message=content,
phase=client.get_current_short_phase(),
)
# Add to game history
game_history.add_message(
phase_name=client.get_current_short_phase(),
sender=power_name,
recipient=recipient,
message_content=content,
)
# Add to agent's journal
journal_recipient = f"to {recipient}" if recipient != GLOBAL else "globally"
agent.add_journal_entry(f"Sent message {journal_recipient} in {client.get_current_short_phase()}: {content[:100]}...")
logger.info(f"[{power_name} -> {recipient}] {content[:100]}...")
return True
async def should_participate_in_negotiations(
client: "WebSocketDiplomacyClient",
agent: "DiplomacyAgent",
) -> bool:
"""
Determine if this bot should participate in negotiations.
Args:
client: WebSocket diplomacy client
agent: The bot's AI agent
Returns:
True if the bot should participate in negotiations
"""
power_name = agent.power_name
# Don't negotiate if eliminated
if client.get_power(power_name).is_eliminated():
return False
# Don't negotiate if no orderable locations
possible_orders = gather_possible_orders(client.game, power_name)
if not possible_orders:
return False
# Only negotiate during movement phases
current_phase = client.get_current_short_phase()
if not current_phase.endswith("M"):
return False
return True
def get_negotiation_delay(round_number: int, total_rounds: int) -> float:
"""
Calculate delay between negotiation rounds to allow message processing.
Args:
round_number: Current round number (1-indexed)
total_rounds: Total number of rounds
Returns:
Delay in seconds
"""
# Longer delay in early rounds to allow more strategic messaging
base_delay = 10.0 # Base delay between rounds
if round_number == 1:
return base_delay * 1.5 # Extra time for first round
elif round_number == total_rounds:
return base_delay * 0.5 # Less time for final round
else:
return base_delay
async def analyze_recent_messages_for_targeting(
client: "WebSocketDiplomacyClient",
power_name: str,
max_messages: int = 20,
) -> List[str]:
"""
Analyze recent messages to identify which powers should be prioritized for negotiations.
Args:
client: WebSocket diplomacy client
power_name: Name of the analyzing power
max_messages: Maximum number of recent messages to analyze
Returns:
List of power names in order of priority for messaging
"""
# Get recent messages from current phase
recent_messages = await client.get_recent_messages(limit=max_messages)
# Track who has been active and who has messaged us
message_activity = {}
direct_messages_to_us = {}
for message in recent_messages:
sender = message.sender
recipient = message.recipient
# Track general activity
if sender != power_name:
message_activity[sender] = message_activity.get(sender, 0) + 1
# Track direct messages to us
if recipient == power_name and sender != power_name:
direct_messages_to_us[sender] = direct_messages_to_us.get(sender, 0) + 1
# Get all active powers
active_powers = [p_name for p_name, p_obj in client.powers.items() if not p_obj.is_eliminated() and p_name != power_name]
# Prioritize based on: 1) Powers that messaged us directly, 2) Most active powers
priority_list = []
# First, add powers that sent us direct messages (sorted by count)
direct_senders = sorted(direct_messages_to_us.items(), key=lambda x: x[1], reverse=True)
for sender, _ in direct_senders:
if sender in active_powers:
priority_list.append(sender)
# Then add other active powers (sorted by activity)
remaining_powers = [p for p in active_powers if p not in priority_list]
activity_sorted = sorted(remaining_powers, key=lambda p: message_activity.get(p, 0), reverse=True)
priority_list.extend(activity_sorted)
logger.debug(f"Message targeting priority for {power_name}: {priority_list}")
return priority_list
async def conduct_strategic_negotiation_round(
client: "WebSocketDiplomacyClient",
agent: "DiplomacyAgent",
game_history: "GameHistory",
model_error_stats: Dict[str, Dict[str, int]],
log_file_path: str,
round_number: int,
max_rounds: int = 3,
) -> bool:
"""
Conduct a single negotiation round with strategic message targeting.
This function analyzes recent message activity to determine which powers
to prioritize for messaging in this round.
Args:
client: WebSocket diplomacy client
agent: The bot's AI agent
game_history: Game history tracker
model_error_stats: Error statistics tracking
log_file_path: Path for logging
round_number: Current round number (1-indexed)
max_rounds: Maximum number of negotiation rounds
Returns:
True if messages were sent successfully, False otherwise
"""
power_name = agent.power_name
# Analyze recent messages to prioritize targets
priority_targets = await analyze_recent_messages_for_targeting(client, power_name)
# Limit to top 3-4 targets in later rounds to focus conversations
if round_number > 1:
priority_targets = priority_targets[: min(4, len(priority_targets))]
logger.info(f"Round {round_number} targets for {power_name}: {priority_targets}")
# Conduct negotiation with prioritized targets
return await conduct_single_bot_negotiation(
client=client,
agent=agent,
game_history=game_history,
model_error_stats=model_error_stats,
log_file_path=log_file_path,
max_rounds=max_rounds,
round_number=round_number,
prioritize_targets=priority_targets,
)