WIP: Exposing errors that cause bot to fail

Working to remove all the lower level try except statements that catch
barely, and don't offer any information to top level about what went
wrong. Better to throw exceptions and let the consumers catch if needed.
This commit is contained in:
Tyler Marques 2025-06-24 19:12:22 -07:00
parent f37c344213
commit c09d5bfe7e
No known key found for this signature in database
GPG key ID: CB99EDCF41D3016F
13 changed files with 1013 additions and 400 deletions

3
.vscode/launch.json vendored
View file

@ -105,7 +105,7 @@
"--game-id",
"${input:gameId}",
"--log-level",
"INFO"
"DEBUG"
],
"console": "integratedTerminal",
"justMyCode": true,
@ -138,6 +138,7 @@
"type": "pickString",
"description": "Select the AI model to use",
"options": [
"gemini-2.5-flash-lite-preview-06-17",
"gpt-3.5-turbo",
"gpt-4",
"claude-3-haiku",

View file

@ -5,6 +5,7 @@ import re
import json_repair
import json5
from ai_diplomacy.game_history import GameHistory
from diplomacy.engine.game import Game # More forgiving JSON parser
# Assuming BaseModelClient is importable from clients.py in the same directory
@ -1150,6 +1151,9 @@ class DiplomacyAgent:
game_history: "GameHistory",
log_file_path: str | Path,
):
assert GameHistory is not None, (
"GameHistory object must be initialized before calling analyze_phase_and_update_state"
)
"""Analyzes the outcome of the last phase and updates goals/relationships using the LLM."""
# Use self.power_name internally
power_name = self.power_name
@ -1169,7 +1173,7 @@ class DiplomacyAgent:
return
# Get previous phase safely from history
if not game_history or not game_history.phases:
if not game_history.phases:
logger.warning(
f"[{power_name}] No game history available to analyze for {game.current_short_phase}. Skipping state update."
)
@ -1254,9 +1258,8 @@ class DiplomacyAgent:
log_entry_success = "FALSE" # Default
update_data = None # Initialize
if (
response is not None and response.strip()
): # Check if response is not None and not just whitespace
# Check if response is not None and not just whitespace
if response is not None and response.strip():
try:
update_data = self._extract_json_from_text(response)
logger.debug(

View file

@ -15,7 +15,6 @@ from anthropic import AsyncAnthropic
import google.generativeai as genai
from diplomacy.engine.message import GLOBAL
from .game_history import GameHistory
from .utils import (
load_prompt,
@ -25,7 +24,6 @@ from .utils import (
# Import DiplomacyAgent for type hinting if needed, but avoid circular import if possible
# from .agent import DiplomacyAgent
from .possible_order_context import generate_rich_order_context
from .prompt_constructor import (
construct_order_generation_prompt,
build_context_prompt,
@ -113,7 +111,6 @@ class BaseModelClient:
raw_response = await run_llm_and_log(
client=self,
prompt=prompt,
log_file_path=log_file_path,
power_name=power_name,
phase=phase,
response_type="order", # Context for run_llm_and_log's own error logging
@ -553,7 +550,6 @@ class BaseModelClient:
raw_response = await run_llm_and_log(
client=self,
prompt=prompt,
log_file_path=log_file_path,
power_name=power_name,
phase=game_phase, # Use game_phase for logging
response_type="plan_reply", # Changed from 'plan' to avoid confusion
@ -604,7 +600,6 @@ class BaseModelClient:
raw_response = await run_llm_and_log(
client=self,
prompt=raw_input_prompt,
log_file_path=log_file_path,
power_name=power_name,
phase=game_phase,
response_type="negotiation", # For run_llm_and_log's internal context
@ -829,7 +824,6 @@ class BaseModelClient:
raw_plan_response = await run_llm_and_log(
client=self, # Pass self (the client instance)
prompt=full_prompt,
log_file_path=log_file_path,
power_name=power_name,
phase=game.current_short_phase,
response_type="plan_generation", # More specific type for run_llm_and_log context
@ -1328,4 +1322,3 @@ def get_visible_messages_for_power(conversation_messages, power_name):
):
visible.append(msg)
return visible # already in chronological order if appended that way

View file

@ -3,7 +3,7 @@ import logging
import re
from typing import TYPE_CHECKING, Optional
from .utils import run_llm_and_log, log_llm_response, load_prompt
from .utils import run_llm_and_log, log_llm_response
if TYPE_CHECKING:
from diplomacy import Game
@ -12,6 +12,25 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> str | None:
"""A local copy of the helper from agent.py to avoid circular imports."""
import os
try:
if prompts_dir:
filepath = os.path.join(prompts_dir, filename)
else:
current_dir = os.path.dirname(os.path.abspath(__file__))
default_prompts_dir = os.path.join(current_dir, "prompts")
filepath = os.path.join(default_prompts_dir, filename)
with open(filepath, "r", encoding="utf-8") as f:
return f.read()
except Exception as e:
logger.error(f"Error loading prompt file {filepath}: {e}")
return None
async def run_diary_consolidation(
agent: "DiplomacyAgent",
game: "Game",
@ -23,19 +42,32 @@ async def run_diary_consolidation(
Consolidate older diary entries while keeping recent ones.
This is the logic moved from the DiplomacyAgent class.
"""
logger.info(f"[{agent.power_name}] CONSOLIDATION START — {len(agent.full_private_diary)} total full entries")
logger.info(
f"[{agent.power_name}] CONSOLIDATION START — "
f"{len(agent.full_private_diary)} total full entries"
)
full_entries = [e for e in agent.full_private_diary if not e.startswith("[CONSOLIDATED HISTORY]")]
full_entries = [
e
for e in agent.full_private_diary
if not e.startswith("[CONSOLIDATED HISTORY]")
]
if len(full_entries) <= entries_to_keep_unsummarized:
agent.private_diary = list(agent.full_private_diary)
logger.info(f"[{agent.power_name}] ≤ {entries_to_keep_unsummarized} full entries — skipping consolidation")
logger.info(
f"[{agent.power_name}] ≤ {entries_to_keep_unsummarized} full entries — "
"skipping consolidation"
)
return
boundary_entry = full_entries[-entries_to_keep_unsummarized]
match = re.search(r"\[[SFWRAB]\s*(\d{4})", boundary_entry)
if not match:
logger.error(f"[{agent.power_name}] Could not parse year from boundary entry; aborting consolidation")
logger.error(
f"[{agent.power_name}] Could not parse year from boundary entry; "
"aborting consolidation"
)
agent.private_diary = list(agent.full_private_diary)
return
@ -46,19 +78,37 @@ async def run_diary_consolidation(
m = re.search(r"\[[SFWRAB]\s*(\d{4})", entry)
return int(m.group(1)) if m else None
entries_to_summarize = [e for e in full_entries if (_entry_year(e) is not None and _entry_year(e) < cutoff_year)]
entries_to_keep = [e for e in full_entries if (_entry_year(e) is None or _entry_year(e) >= cutoff_year)]
entries_to_summarize = [
e
for e in full_entries
if (_entry_year(e) is not None and _entry_year(e) < cutoff_year)
]
entries_to_keep = [
e
for e in full_entries
if (_entry_year(e) is None or _entry_year(e) >= cutoff_year)
]
logger.info(f"[{agent.power_name}] Summarising {len(entries_to_summarize)} entries; keeping {len(entries_to_keep)} recent entries verbatim")
logger.info(
f"[{agent.power_name}] Summarising {len(entries_to_summarize)} entries; "
f"keeping {len(entries_to_keep)} recent entries verbatim"
)
if not entries_to_summarize:
agent.private_diary = list(agent.full_private_diary)
logger.warning(f"[{agent.power_name}] No eligible entries to summarise; context diary left unchanged")
logger.warning(
f"[{agent.power_name}] No eligible entries to summarise; "
"context diary left unchanged"
)
return
prompt_template = load_prompt("diary_consolidation_prompt.txt", prompts_dir=prompts_dir)
prompt_template = _load_prompt_file(
"diary_consolidation_prompt.txt", prompts_dir=prompts_dir
)
if not prompt_template:
logger.error(f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting")
logger.error(
f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting"
)
return
prompt = prompt_template.format(
@ -87,14 +137,23 @@ async def run_diary_consolidation(
new_summary_entry = f"[CONSOLIDATED HISTORY] {consolidated_text}"
agent.private_diary = [new_summary_entry] + entries_to_keep
success_flag = "TRUE"
logger.info(f"[{agent.power_name}] Consolidation complete — {len(agent.private_diary)} context entries now")
logger.info(
f"[{agent.power_name}] Consolidation complete — "
f"{len(agent.private_diary)} context entries now"
)
except Exception as exc:
logger.error(f"[{agent.power_name}] Diary consolidation failed: {exc}", exc_info=True)
logger.error(
f"[{agent.power_name}] Diary consolidation failed: {exc}", exc_info=True
)
finally:
log_llm_response(
log_file_path=log_file_path,
model_name=(consolidation_client.model_name if consolidation_client is not None else agent.client.model_name),
model_name=(
consolidation_client.model_name
if consolidation_client is not None
else agent.client.model_name
),
power_name=agent.power_name,
phase=game.current_short_phase,
response_type="diary_consolidation",
@ -102,3 +161,4 @@ async def run_diary_consolidation(
raw_response=raw_response,
success=success_flag,
)

View file

@ -148,7 +148,6 @@ class GameHistory:
return {}
return self.phases[-1].plans
# NEW METHOD
def get_messages_this_round(self, power_name: str, current_phase_name: str) -> str:
current_phase: Optional[Phase] = None
for phase_obj in self.phases:

View file

@ -1,4 +1,4 @@
from os.path import exists
import datetime
from pydantic_settings import BaseSettings
from pathlib import Path
import warnings
@ -6,16 +6,16 @@ import warnings
class Configuration(BaseSettings):
DEBUG: bool = False
log_file_path: Path = Path("./logs/logs.txt")
log_file_path: Path
DEEPSEEK_API_KEY: str | None = None
OPENAI_API_KEY: str | None = None
ANTHROPIC_API_KEY: str | None = None
GEMINI_API_KEY: str | None = None
OPENROUTER_API_KEY: str | None = None
def __init__(self, **kwargs):
def __init__(self, power_name, **kwargs):
super().__init__(**kwargs)
self.log_file_path = Path(f"./logs/{datetime.datetime.now().strftime('%d-%m-%y_%H:%M')}/{power_name}.txt")
# Make the path absolute, gets rid of weirdness of calling this in different places
self.log_file_path = self.log_file_path.resolve()
self.log_file_path.parent.mkdir(parents=True, exist_ok=True)
@ -42,11 +42,7 @@ class Configuration(BaseSettings):
"""Override to check for empty API keys at access time"""
value = super().__getattribute__(name)
if name.endswith("_KEY") and (
not value or (isinstance(value, str) and len(value) == 0)
):
raise ValueError(
f"API key '{name}' is not set or is empty. Please configure it before use."
)
if name.endswith("_KEY") and (not value or (isinstance(value, str) and len(value) == 0)):
raise ValueError(f"API key '{name}' is not set or is empty. Please configure it before use.")
return value

View file

@ -30,12 +30,9 @@ from websocket_diplomacy_client import (
from models import (
SignInRequest,
CreateGameRequest,
JoinGameRequest,
SetOrdersRequest,
ListGamesRequest,
ProcessGameRequest,
serialize_message,
parse_message,
)
import sys
@ -64,9 +61,7 @@ logging.getLogger("root").setLevel(logging.WARNING)
def parse_arguments():
parser = argparse.ArgumentParser(
description="Run a Diplomacy game simulation with WebSocket server connection."
)
parser = argparse.ArgumentParser(description="Run a Diplomacy game simulation with WebSocket server connection.")
parser.add_argument(
"--hostname",
type=str,
@ -137,9 +132,7 @@ def parse_arguments():
return parser.parse_args()
async def join_powers_for_testing(
client: WebSocketDiplomacyClient, power_model_map: dict
):
async def join_powers_for_testing(client: WebSocketDiplomacyClient, power_model_map: dict):
"""
Join multiple powers in the same game for testing purposes.
This simulates having multiple AI players in one game.
@ -150,17 +143,13 @@ async def join_powers_for_testing(
for power_name in power_names[1:]:
try:
logger.info(f"Attempting to join power {power_name}")
await client.channel.join_game(
game_id=client.game_id, power_name=power_name
)
await client.channel.join_game(game_id=client.game_id, power_name=power_name)
logger.info(f"Successfully joined {power_name}")
except Exception as e:
logger.warning(f"Could not join {power_name}: {e}")
async def create_or_join_game(
client: WebSocketDiplomacyClient, args, power_model_map: dict
):
async def create_or_join_game(client: WebSocketDiplomacyClient, args, power_model_map: dict):
"""
Create a new game or join an existing one based on arguments.
"""
@ -171,9 +160,7 @@ async def create_or_join_game(
# List available games first to see what's available
try:
games = await client.list_games()
logger.info(
f"Available games: {[g.get('game_id', 'unknown') for g in games]}"
)
logger.info(f"Available games: {[g.get('game_id', 'unknown') for g in games]}")
except Exception as e:
logger.warning(f"Could not list games: {e}")
@ -189,19 +176,13 @@ async def create_or_join_game(
logger.info("Creating new game")
# Get the first power to control
first_power = (
list(power_model_map.keys())[0]
if not args.create_multi_power_game
else None
)
first_power = list(power_model_map.keys())[0] if not args.create_multi_power_game else None
game = await client.create_game(
map_name="standard",
rules=["NO_PRESS", "IGNORE_ERRORS", "POWER_CHOICE"],
power_name=first_power,
n_controls=7
if not args.create_multi_power_game
else 1, # Lower requirement for testing
n_controls=7 if not args.create_multi_power_game else 1, # Lower requirement for testing
deadline=None, # No time pressure for AI testing
)
@ -215,21 +196,17 @@ async def demonstrate_typed_messages(client: WebSocketDiplomacyClient, args):
"""
Demonstrate how to use typed WebSocket messages.
This shows the proper way to construct messages according to the protocol.
NOTE: This is for demonstration only - the current client implementation
uses the diplomacy library's built-in networking which abstracts the raw protocol.
"""
logger.info("=== Typed Message Examples ===")
# Example 1: Authentication Request
sign_in_request = SignInRequest(
request_id=str(uuid.uuid4()),
username=args.username,
password=args.password
)
sign_in_request = SignInRequest(request_id=str(uuid.uuid4()), username=args.username, password=args.password)
logger.info(f"SignIn request: {serialize_message(sign_in_request)}")
# Example 2: Create Game Request
create_game_request = CreateGameRequest(
request_id=str(uuid.uuid4()),
@ -237,29 +214,25 @@ async def demonstrate_typed_messages(client: WebSocketDiplomacyClient, args):
map_name="standard",
rules=["NO_PRESS", "IGNORE_ERRORS", "POWER_CHOICE"],
n_controls=1,
power_name="FRANCE"
power_name="FRANCE",
)
logger.info(f"CreateGame request: {serialize_message(create_game_request)}")
# Example 3: List Games Request
list_games_request = ListGamesRequest(
request_id=str(uuid.uuid4()),
token="example_auth_token_here",
include_protected=False
)
list_games_request = ListGamesRequest(request_id=str(uuid.uuid4()), token="example_auth_token_here", include_protected=False)
logger.info(f"ListGames request: {serialize_message(list_games_request)}")
# Example 4: Set Orders Request (would be used in game loop)
if hasattr(client, 'game_id') and hasattr(client, 'token'):
if hasattr(client, "game_id") and hasattr(client, "token"):
set_orders_request = SetOrdersRequest(
request_id=str(uuid.uuid4()),
token=client.token or "example_token",
game_id=client.game_id or "example_game_id",
game_role="FRANCE",
orders=["A PAR H", "F BRE H", "A MAR H"]
orders=["A PAR H", "F BRE H", "A MAR H"],
)
logger.info(f"SetOrders request: {serialize_message(set_orders_request)}")
logger.info("=== End Typed Message Examples ===")
@ -282,9 +255,7 @@ async def main():
logger.info("Starting WebSocket-based Diplomacy game with multiple LLMs")
start_whole = time.time()
model_error_stats = defaultdict(
lambda: {"conversation_errors": 0, "order_decoding_errors": 0}
)
model_error_stats = defaultdict(lambda: {"conversation_errors": 0, "order_decoding_errors": 0})
# Determine the result folder based on a timestamp
timestamp_str = time.strftime("%Y%m%d_%H%M%S")
@ -306,9 +277,7 @@ async def main():
# File paths
manifesto_path = f"{result_folder}/game_manifesto.txt"
game_file_path = (
args.output if args.output else f"{result_folder}/lmvsgame_websocket.json"
)
game_file_path = args.output if args.output else f"{result_folder}/lmvsgame_websocket.json"
overview_file_path = f"{result_folder}/overview.jsonl"
llm_log_file_path = f"{result_folder}/llm_responses.csv"
@ -325,9 +294,7 @@ async def main():
]
provided_models = [name.strip() for name in args.models.split(",")]
if len(provided_models) != len(powers_order):
logger.error(
f"Expected {len(powers_order)} models for --models but got {len(provided_models)}. Exiting."
)
logger.error(f"Expected {len(powers_order)} models for --models but got {len(provided_models)}. Exiting.")
return
power_model_map = dict(zip(powers_order, provided_models))
else:
@ -377,32 +344,20 @@ async def main():
client_obj = load_model_client(model_id)
agent = DiplomacyAgent(power_name=power_name, client=client_obj)
agents[power_name] = agent
logger.info(
f"Preparing initialization task for {power_name} with model {model_id}"
)
initialization_tasks.append(
initialize_agent_state_ext(
agent, client.game, game_history, llm_log_file_path
)
)
logger.info(f"Preparing initialization task for {power_name} with model {model_id}")
initialization_tasks.append(initialize_agent_state_ext(agent, client.game, game_history, llm_log_file_path))
except Exception as e:
logger.error(
f"Failed to create agent for {power_name} with model {model_id}: {e}",
exc_info=True,
)
else:
logger.info(
f"Skipping agent initialization for {power_name} (no model or eliminated)"
)
logger.info(f"Skipping agent initialization for {power_name} (no model or eliminated)")
# Run initializations concurrently
if initialization_tasks:
logger.info(
f"Running {len(initialization_tasks)} agent initializations concurrently..."
)
initialization_results = await asyncio.gather(
*initialization_tasks, return_exceptions=True
)
logger.info(f"Running {len(initialization_tasks)} agent initializations concurrently...")
initialization_results = await asyncio.gather(*initialization_tasks, return_exceptions=True)
initialized_powers = list(agents.keys())
for i, result in enumerate(initialization_results):
@ -414,9 +369,7 @@ async def main():
exc_info=result,
)
else:
logger.info(
f"Successfully initialized agent state for {power_name}."
)
logger.info(f"Successfully initialized agent state for {power_name}.")
# Main game loop
all_phase_relationships = {}
@ -433,9 +386,7 @@ async def main():
game_history.add_phase(current_phase)
current_short_phase = client.get_current_short_phase()
logger.info(
f"PHASE: {current_phase} (time so far: {phase_start - start_whole:.2f}s)"
)
logger.info(f"PHASE: {current_phase} (time so far: {phase_start - start_whole:.2f}s)")
# Prevent unbounded simulation based on year
year_str = current_phase[1:5]
@ -447,9 +398,7 @@ async def main():
# Negotiations for movement phases
if client.get_current_short_phase().endswith("M"):
if args.num_negotiation_rounds > 0:
logger.info(
f"Running {args.num_negotiation_rounds} rounds of negotiations..."
)
logger.info(f"Running {args.num_negotiation_rounds} rounds of negotiations...")
game_history = await conduct_negotiations(
client.game, # Pass the NetworkGame object
agents,
@ -459,9 +408,7 @@ async def main():
log_file_path=llm_log_file_path,
)
else:
logger.info(
"Skipping negotiation phase as num_negotiation_rounds=0"
)
logger.info("Skipping negotiation phase as num_negotiation_rounds=0")
# Planning phase (if enabled)
if args.planning_phase:
@ -475,29 +422,19 @@ async def main():
)
# Generate negotiation diary entries
logger.info(
f"Generating negotiation diary entries for phase {current_short_phase}..."
)
active_powers_for_neg_diary = [
p for p in agents.keys() if not client.get_power(p).is_eliminated()
]
logger.info(f"Generating negotiation diary entries for phase {current_short_phase}...")
active_powers_for_neg_diary = [p for p in agents.keys() if not client.get_power(p).is_eliminated()]
neg_diary_tasks = []
for power_name, agent in agents.items():
if not client.get_power(power_name).is_eliminated():
neg_diary_tasks.append(
agent.generate_negotiation_diary_entry(
client.game, game_history, llm_log_file_path
)
)
neg_diary_tasks.append(agent.generate_negotiation_diary_entry(client.game, game_history, llm_log_file_path))
if neg_diary_tasks:
await asyncio.gather(*neg_diary_tasks, return_exceptions=True)
# AI Decision Making: Get orders for each controlled power
logger.info("Getting orders from agents...")
active_powers_for_orders = [
p for p in agents.keys() if not client.get_power(p).is_eliminated()
]
active_powers_for_orders = [p for p in agents.keys() if not client.get_power(p).is_eliminated()]
order_tasks = []
order_power_names = []
@ -505,35 +442,23 @@ async def main():
for power_name, agent in agents.items():
if client.get_power(power_name).is_eliminated():
logger.debug(
f"Skipping order generation for eliminated power {power_name}."
)
logger.debug(f"Skipping order generation for eliminated power {power_name}.")
continue
# Diagnostic logging
logger.info(
f"--- Diagnostic Log for {power_name} in phase {current_phase} ---"
)
logger.info(f"--- Diagnostic Log for {power_name} in phase {current_phase} ---")
try:
orderable_locs = client.get_orderable_locations(power_name)
logger.info(
f"[{power_name}][{current_phase}] Orderable locations: {orderable_locs}"
)
logger.info(f"[{power_name}][{current_phase}] Orderable locations: {orderable_locs}")
actual_units = client.get_units(power_name)
logger.info(
f"[{power_name}][{current_phase}] Actual units: {actual_units}"
)
logger.info(f"[{power_name}][{current_phase}] Actual units: {actual_units}")
except Exception as e_diag:
logger.error(
f"[{power_name}][{current_phase}] Error during diagnostic logging: {e_diag}"
)
logger.error(f"[{power_name}][{current_phase}] Error during diagnostic logging: {e_diag}")
# Calculate possible orders
possible_orders = gather_possible_orders(client.game, power_name)
if not possible_orders:
logger.debug(
f"No orderable locations for {power_name}; submitting empty orders."
)
logger.debug(f"No orderable locations for {power_name}; submitting empty orders.")
await client.set_orders(power_name, [])
continue
@ -559,12 +484,8 @@ async def main():
# Run order generation concurrently
if order_tasks:
logger.debug(
f"Running {len(order_tasks)} order generation tasks concurrently..."
)
order_results = await asyncio.gather(
*order_tasks, return_exceptions=True
)
logger.debug(f"Running {len(order_tasks)} order generation tasks concurrently...")
order_results = await asyncio.gather(*order_tasks, return_exceptions=True)
else:
order_results = []
@ -580,9 +501,7 @@ async def main():
)
await client.set_orders(p_name, [])
elif result is None:
logger.warning(
f"get_valid_orders returned None for {p_name}. Setting empty orders."
)
logger.warning(f"get_valid_orders returned None for {p_name}. Setting empty orders.")
await client.set_orders(p_name, [])
else:
orders = result
@ -598,17 +517,13 @@ async def main():
# orders=orders
# )
# await client.send_typed_message(set_orders_msg)
await client.set_orders(p_name, orders)
logger.debug(
f"Set orders for {p_name} in {current_short_phase}: {orders}"
)
logger.debug(f"Set orders for {p_name} in {current_short_phase}: {orders}")
# Generate order diary entry
try:
await agent.generate_order_diary_entry(
client.game, orders, llm_log_file_path
)
await agent.generate_order_diary_entry(client.game, orders, llm_log_file_path)
except Exception as e_diary:
logger.error(
f"Error generating order diary for {p_name}: {e_diary}",
@ -627,7 +542,7 @@ async def main():
# phase=current_phase
# )
# await client.send_typed_message(process_msg)
logger.info(f"Processing orders for {current_phase}...")
await simulate_game_processing(client)
@ -650,22 +565,13 @@ async def main():
# Collect relationships for this phase
current_relationships_for_phase = {}
for power_name, agent in agents.items():
if (
power_name in client.powers
and not client.get_power(power_name).is_eliminated()
):
if power_name in client.powers and not client.get_power(power_name).is_eliminated():
current_relationships_for_phase[power_name] = agent.relationships
all_phase_relationships[current_short_phase] = (
current_relationships_for_phase
)
all_phase_relationships[current_short_phase] = current_relationships_for_phase
# Generate phase result diary entries
logger.info(
f"Generating phase result diary entries for completed phase {current_phase}..."
)
phase_summary = getattr(client.game, "phase_summaries", {}).get(
current_phase, "(Summary not generated)"
)
logger.info(f"Generating phase result diary entries for completed phase {current_phase}...")
phase_summary = getattr(client.game, "phase_summaries", {}).get(current_phase, "(Summary not generated)")
all_orders_this_phase = current_order_history
phase_result_diary_tasks = []
@ -685,16 +591,10 @@ async def main():
await asyncio.gather(*phase_result_diary_tasks, return_exceptions=True)
# State update analysis
logger.info(
f"Starting state update analysis for completed phase {current_phase}..."
)
logger.info(f"Starting state update analysis for completed phase {current_phase}...")
current_board_state = client.get_state()
active_agent_powers = [
(p, power)
for p, power in client.powers.items()
if p in agents and not power.is_eliminated()
]
active_agent_powers = [(p, power) for p, power in client.powers.items() if p in agents and not power.is_eliminated()]
if active_agent_powers:
state_update_tasks = []
@ -781,4 +681,3 @@ async def main():
if __name__ == "__main__":
asyncio.run(main())

View file

@ -8,34 +8,39 @@ and notifications (server -> client).
from typing import Optional, List, Dict, Any, Union, Literal
from pydantic import BaseModel, Field
from abc import ABC, abstractmethod
from abc import ABC
# =============================================================================
# Base Message Classes
# =============================================================================
class BaseMessage(BaseModel, ABC):
"""Base class for all WebSocket messages."""
name: str
class Config:
extra = "forbid"
class BaseRequest(BaseMessage):
"""Base class for all client -> server requests."""
request_id: str
re_sent: bool = False
class BaseResponse(BaseMessage):
"""Base class for all server -> client responses."""
request_id: str
class BaseNotification(BaseMessage):
"""Base class for all server -> client notifications (no request_id)."""
pass
@ -43,8 +48,10 @@ class BaseNotification(BaseMessage):
# Authentication & Connection Level Messages
# =============================================================================
class SignInRequest(BaseRequest):
"""Client authentication request."""
name: Literal["sign_in"] = "sign_in"
username: str
password: str
@ -52,6 +59,7 @@ class SignInRequest(BaseRequest):
class GetDaidePortRequest(BaseRequest):
"""Request DAIDE TCP port for a game."""
name: Literal["get_daide_port"] = "get_daide_port"
game_id: str
@ -60,13 +68,16 @@ class GetDaidePortRequest(BaseRequest):
# Channel Level Messages (require authentication token)
# =============================================================================
class ChannelRequest(BaseRequest):
"""Base class for channel-level requests that require authentication."""
token: str
class CreateGameRequest(ChannelRequest):
"""Create a new game."""
name: Literal["create_game"] = "create_game"
map_name: str = "standard"
rules: List[str] = Field(default_factory=lambda: ["NO_PRESS", "IGNORE_ERRORS"])
@ -78,6 +89,7 @@ class CreateGameRequest(ChannelRequest):
class JoinGameRequest(ChannelRequest):
"""Join an existing game."""
name: Literal["join_game"] = "join_game"
game_id: str
power_name: Optional[str] = None
@ -86,6 +98,7 @@ class JoinGameRequest(ChannelRequest):
class JoinPowersRequest(ChannelRequest):
"""Join multiple powers in a game."""
name: Literal["join_powers"] = "join_powers"
game_id: str
power_names: List[str]
@ -94,32 +107,37 @@ class JoinPowersRequest(ChannelRequest):
class ListGamesRequest(ChannelRequest):
"""List available games."""
name: Literal["list_games"] = "list_games"
game_id_filter: Optional[str] = None
map_name: Optional[str] = None
status: Optional[str] = None
status: Optional[str] = None
include_protected: bool = False
class GetPlayablePowersRequest(ChannelRequest):
"""Get uncontrolled powers in a game."""
name: Literal["get_playable_powers"] = "get_playable_powers"
game_id: str
class GetAvailableMapsRequest(ChannelRequest):
"""Get list of available maps."""
name: Literal["get_available_maps"] = "get_available_maps"
class GetDummyWaitingPowersRequest(ChannelRequest):
"""Get AI-controllable powers (bot use)."""
name: Literal["get_dummy_waiting_powers"] = "get_dummy_waiting_powers"
game_id: str
class SetGradeRequest(ChannelRequest):
"""Modify user permissions."""
name: Literal["set_grade"] = "set_grade"
username: str
grade: str
@ -127,11 +145,13 @@ class SetGradeRequest(ChannelRequest):
class DeleteAccountRequest(ChannelRequest):
"""Delete user account."""
name: Literal["delete_account"] = "delete_account"
class LogoutRequest(ChannelRequest):
"""Disconnect from server."""
name: Literal["logout"] = "logout"
@ -139,8 +159,10 @@ class LogoutRequest(ChannelRequest):
# Game Level Messages (require authentication + game context)
# =============================================================================
class GameRequest(ChannelRequest):
"""Base class for game-level requests."""
game_id: str
game_role: str # Power name like "ENGLAND"
phase: Optional[str] = None
@ -148,18 +170,21 @@ class GameRequest(ChannelRequest):
class SetOrdersRequest(GameRequest):
"""Submit orders for a power."""
name: Literal["set_orders"] = "set_orders"
orders: List[str]
class SetWaitFlagRequest(GameRequest):
"""Set wait flag for turn processing."""
name: Literal["set_wait_flag"] = "set_wait_flag"
wait: bool
class SendGameMessageRequest(GameRequest):
"""Send diplomatic message."""
name: Literal["send_game_message"] = "send_game_message"
recipient: str # Power name or "GLOBAL"
message: str
@ -168,11 +193,13 @@ class SendGameMessageRequest(GameRequest):
class GetAllPossibleOrdersRequest(GameRequest):
"""Get legal orders for current phase."""
name: Literal["get_all_possible_orders"] = "get_all_possible_orders"
class GetPhaseHistoryRequest(GameRequest):
"""Get historical game phases."""
name: Literal["get_phase_history"] = "get_phase_history"
from_phase: Optional[str] = None
to_phase: Optional[str] = None
@ -180,45 +207,53 @@ class GetPhaseHistoryRequest(GameRequest):
class ProcessGameRequest(GameRequest):
"""Force game processing (master only)."""
name: Literal["process_game"] = "process_game"
class VoteRequest(GameRequest):
"""Vote for/against draw."""
name: Literal["vote"] = "vote"
vote: Literal["yes", "no"]
class SaveGameRequest(GameRequest):
"""Export game as JSON."""
name: Literal["save_game"] = "save_game"
class SetGameStateRequest(GameRequest):
"""Modify game state (master only)."""
name: Literal["set_game_state"] = "set_game_state"
state: Dict[str, Any]
class SetGameStatusRequest(GameRequest):
"""Change game status (master only)."""
name: Literal["set_game_status"] = "set_game_status"
status: str
class SetDummyPowersRequest(GameRequest):
"""Make powers AI-controlled (master only)."""
name: Literal["set_dummy_powers"] = "set_dummy_powers"
power_names: List[str]
class DeleteGameRequest(GameRequest):
"""Delete game (master only)."""
name: Literal["delete_game"] = "delete_game"
class LeaveGameRequest(GameRequest):
"""Leave game."""
name: Literal["leave_game"] = "leave_game"
@ -226,13 +261,16 @@ class LeaveGameRequest(GameRequest):
# Response Messages (Server -> Client)
# =============================================================================
class OkResponse(BaseResponse):
"""Generic success response."""
name: Literal["ok"] = "ok"
class ErrorResponse(BaseResponse):
"""Error response with error type and message."""
name: Literal["error"] = "error"
error_type: str
message: str
@ -240,60 +278,70 @@ class ErrorResponse(BaseResponse):
class DataTokenResponse(BaseResponse):
"""Contains authentication token."""
name: Literal["data_token"] = "data_token"
data: str # The authentication token
class DataGameResponse(BaseResponse):
"""Contains full game object."""
name: Literal["data_game"] = "data_game"
data: Dict[str, Any] # The complete game state
class DataGameInfoResponse(BaseResponse):
"""Contains game metadata."""
name: Literal["data_game_info"] = "data_game_info"
data: Dict[str, Any]
class DataGamesResponse(BaseResponse):
"""List of game information."""
name: Literal["data_games"] = "data_games"
data: List[Dict[str, Any]]
class DataMapsResponse(BaseResponse):
"""Available maps information."""
name: Literal["data_maps"] = "data_maps"
data: List[str]
class DataPowerNamesResponse(BaseResponse):
"""List of power names."""
name: Literal["data_power_names"] = "data_power_names"
data: List[str]
class DataPossibleOrdersResponse(BaseResponse):
"""Legal orders and locations."""
name: Literal["data_possible_orders"] = "data_possible_orders"
data: Dict[str, List[str]] # Location -> list of possible orders
class DataGamePhasesResponse(BaseResponse):
"""Historical game phases."""
name: Literal["data_game_phases"] = "data_game_phases"
data: List[Dict[str, Any]]
class DataSavedGameResponse(BaseResponse):
"""Exported game JSON."""
name: Literal["data_saved_game"] = "data_saved_game"
data: Dict[str, Any]
class DataPortResponse(BaseResponse):
"""DAIDE port number."""
name: Literal["data_port"] = "data_port"
data: int
@ -302,8 +350,10 @@ class DataPortResponse(BaseResponse):
# Notification Messages (Server -> Client)
# =============================================================================
class GameProcessedNotification(BaseNotification):
"""Phase completed, new orders phase."""
name: Literal["game_processed"] = "game_processed"
game_id: str
phase: str
@ -312,6 +362,7 @@ class GameProcessedNotification(BaseNotification):
class GamePhaseUpdateNotification(BaseNotification):
"""Game state changed."""
name: Literal["game_phase_update"] = "game_phase_update"
game_id: str
phase: str
@ -320,6 +371,7 @@ class GamePhaseUpdateNotification(BaseNotification):
class GameStatusUpdateNotification(BaseNotification):
"""Game status changed (forming/active/paused/completed)."""
name: Literal["game_status_update"] = "game_status_update"
game_id: str
status: str
@ -327,6 +379,7 @@ class GameStatusUpdateNotification(BaseNotification):
class PowersControllersNotification(BaseNotification):
"""Power control assignments changed."""
name: Literal["powers_controllers"] = "powers_controllers"
game_id: str
controllers: Dict[str, str] # Power -> Controller mapping
@ -334,6 +387,7 @@ class PowersControllersNotification(BaseNotification):
class PowerOrdersUpdateNotification(BaseNotification):
"""Player submitted new orders."""
name: Literal["power_orders_update"] = "power_orders_update"
game_id: str
power_name: str
@ -343,15 +397,17 @@ class PowerOrdersUpdateNotification(BaseNotification):
class PowerOrdersFlagNotification(BaseNotification):
"""Player order submission status."""
name: Literal["power_orders_flag"] = "power_orders_flag"
game_id: str
power_name: str
power_name: str
order_is_set: bool
phase: str
class PowerWaitFlagNotification(BaseNotification):
"""Player wait flag changed."""
name: Literal["power_wait_flag"] = "power_wait_flag"
game_id: str
power_name: str
@ -360,6 +416,7 @@ class PowerWaitFlagNotification(BaseNotification):
class GameMessageReceivedNotification(BaseNotification):
"""Diplomatic message received."""
name: Literal["game_message_received"] = "game_message_received"
game_id: str
sender: str
@ -371,6 +428,7 @@ class GameMessageReceivedNotification(BaseNotification):
class VoteUpdatedNotification(BaseNotification):
"""Draw votes changed (omniscient view)."""
name: Literal["vote_updated"] = "vote_updated"
game_id: str
votes: Dict[str, str] # Power -> vote mapping
@ -378,6 +436,7 @@ class VoteUpdatedNotification(BaseNotification):
class VoteCountUpdatedNotification(BaseNotification):
"""Vote count changed (observer view)."""
name: Literal["vote_count_updated"] = "vote_count_updated"
game_id: str
count_yes: int
@ -386,6 +445,7 @@ class VoteCountUpdatedNotification(BaseNotification):
class PowerVoteUpdatedNotification(BaseNotification):
"""Own power's vote changed."""
name: Literal["power_vote_updated"] = "power_vote_updated"
game_id: str
power_name: str
@ -394,12 +454,14 @@ class PowerVoteUpdatedNotification(BaseNotification):
class GameDeletedNotification(BaseNotification):
"""Game removed from server."""
name: Literal["game_deleted"] = "game_deleted"
game_id: str
class OmniscientUpdatedNotification(BaseNotification):
"""Observer permissions changed."""
name: Literal["omniscient_updated"] = "omniscient_updated"
game_id: str
omniscient_type: str
@ -407,12 +469,14 @@ class OmniscientUpdatedNotification(BaseNotification):
class AccountDeletedNotification(BaseNotification):
"""User account deleted."""
name: Literal["account_deleted"] = "account_deleted"
username: str
class ClearedCentersNotification(BaseNotification):
"""Supply centers cleared."""
name: Literal["cleared_centers"] = "cleared_centers"
game_id: str
power_name: str
@ -420,6 +484,7 @@ class ClearedCentersNotification(BaseNotification):
class ClearedOrdersNotification(BaseNotification):
"""Orders cleared."""
name: Literal["cleared_orders"] = "cleared_orders"
game_id: str
power_name: str
@ -428,6 +493,7 @@ class ClearedOrdersNotification(BaseNotification):
class ClearedUnitsNotification(BaseNotification):
"""Units cleared."""
name: Literal["cleared_units"] = "cleared_units"
game_id: str
power_name: str
@ -514,24 +580,25 @@ WebSocketMessage = Union[RequestMessage, ResponseMessage, NotificationMessage]
# Utility Functions
# =============================================================================
def parse_message(data: Dict[str, Any]) -> WebSocketMessage:
"""
Parse a raw WebSocket message dictionary into the appropriate pydantic model.
Args:
data: Raw message dictionary from WebSocket
Returns:
Parsed message object
Raises:
ValueError: If message cannot be parsed or is of unknown type
"""
if not isinstance(data, dict) or "name" not in data:
raise ValueError("Invalid message format: missing 'name' field")
message_name = data["name"]
# Map message names to their corresponding classes
message_classes = {
# Requests
@ -592,22 +659,22 @@ def parse_message(data: Dict[str, Any]) -> WebSocketMessage:
"cleared_orders": ClearedOrdersNotification,
"cleared_units": ClearedUnitsNotification,
}
message_class = message_classes.get(message_name)
if message_class is None:
raise ValueError(f"Unknown message type: {message_name}")
return message_class(**data)
def serialize_message(message: WebSocketMessage) -> Dict[str, Any]:
"""
Serialize a pydantic message object to a dictionary for WebSocket transmission.
Args:
message: Pydantic message object
Returns:
Dictionary representation of the message
"""
return message.model_dump(exclude_none=True)
return message.model_dump(exclude_none=True)

View file

@ -50,13 +50,13 @@ class MultiBotLauncher:
# Default power to model mapping
self.default_models = {
"AUSTRIA": "gpt-3.5-turbo",
"ENGLAND": "gpt-4",
"FRANCE": "claude-3-haiku",
"GERMANY": "gpt-3.5-turbo",
"ITALY": "gemini-pro",
"RUSSIA": "gpt-4",
"TURKEY": "claude-3-sonnet",
"AUSTRIA": "gemini-2.5-flash-lite-preview-06-17",
"ENGLAND": "gemini-2.5-flash-lite-preview-06-17",
"FRANCE": "gemini-2.5-flash-lite-preview-06-17",
"GERMANY": "gemini-2.5-flash-lite-preview-06-17",
"ITALY": "gemini-2.5-flash-lite-preview-06-17",
"RUSSIA": "gemini-2.5-flash-lite-preview-06-17",
"TURKEY": "gemini-2.5-flash-lite-preview-06-17",
}
async def create_game(self, creator_power: str = "FRANCE") -> str:
@ -99,7 +99,12 @@ class MultiBotLauncher:
return game_id
def launch_bot(
self, power: str, model: str, game_id: str, log_level: str = "INFO"
self,
power: str,
model: str,
game_id: str,
log_level: str = "INFO",
negotiation_rounds: int = 3,
) -> subprocess.Popen:
"""
Launch a single bot process.
@ -134,6 +139,8 @@ class MultiBotLauncher:
game_id,
"--log-level",
log_level,
"--negotiation-rounds",
str(negotiation_rounds),
]
logger.info(f"Launching bot for {power} with model {model}")
@ -156,7 +163,8 @@ class MultiBotLauncher:
models: Optional[Dict[str, str]] = None,
powers: Optional[List[str]] = None,
log_level: str = "INFO",
stagger_delay: float = 2.0,
stagger_delay: float = 0.5,
negotiation_rounds: int = 3,
):
"""
Launch bots for all specified powers.
@ -180,13 +188,11 @@ class MultiBotLauncher:
model = models.get(power, "gpt-3.5-turbo")
try:
process = self.launch_bot(power, model, game_id, log_level)
process = self.launch_bot(power, model, game_id, log_level, negotiation_rounds)
self.bot_processes.append(process)
self.process_to_power[process] = power
logger.info(
f"Launched bot {i + 1}/{len(powers)}: {power} (PID: {process.pid})"
)
logger.info(f"Launched bot {i + 1}/{len(powers)}: {power} (PID: {process.pid})")
# Stagger the launches to avoid overwhelming the server
if i < len(powers) - 1: # Don't delay after the last bot
@ -209,11 +215,11 @@ class MultiBotLauncher:
try:
while self.bot_processes:
active_processes = []
# Collect all stdout file descriptors from active processes
stdout_fds = []
fd_to_process = {}
for process in self.bot_processes:
if process.poll() is None: # Still running
active_processes.append(process)
@ -224,15 +230,13 @@ class MultiBotLauncher:
# Process has ended
return_code = process.returncode
power = self.process_to_power.get(process, "UNKNOWN")
logger.info(
f"{power} bot process {process.pid} ended with code {return_code}"
)
logger.info(f"{power} bot process {process.pid} ended with code {return_code}")
# Read any remaining output
remaining_output = process.stdout.read()
if remaining_output:
print(f"{power}_{process.pid} final output: {remaining_output}")
# Clean up the power mapping
self.process_to_power.pop(process, None)
@ -241,20 +245,20 @@ class MultiBotLauncher:
if not self.bot_processes:
logger.info("All bots have finished")
break
# Use select to check which processes have output ready (Unix only)
if stdout_fds and hasattr(select, 'select'):
if stdout_fds and hasattr(select, "select"):
try:
ready_fds, _, _ = select.select(stdout_fds, [], [], 0.1) # 100ms timeout
for fd in ready_fds:
process = fd_to_process[fd]
power = self.process_to_power.get(process, "UNKNOWN")
# Read available lines (but limit to prevent monopolizing)
lines_read = 0
max_lines_per_process = 10
while lines_read < max_lines_per_process:
try:
line = process.stdout.readline()
@ -264,7 +268,7 @@ class MultiBotLauncher:
lines_read += 1
except:
break
except (OSError, ValueError):
# Fallback if select fails
self._fallback_read_output(active_processes)
@ -283,11 +287,11 @@ class MultiBotLauncher:
"""Fallback method for reading output when select is not available."""
for process in active_processes:
power = self.process_to_power.get(process, "UNKNOWN")
# Read limited lines per process to prevent monopolizing
lines_read = 0
max_lines_per_process = 3 # More conservative for fallback
while lines_read < max_lines_per_process:
try:
line = process.stdout.readline()
@ -324,6 +328,7 @@ class MultiBotLauncher:
models: Optional[Dict[str, str]] = None,
log_level: str = "INFO",
creator_power: str = "FRANCE",
negotiation_rounds: int = 3,
):
"""
Create a game and launch all bots for a complete game.
@ -342,7 +347,12 @@ class MultiBotLauncher:
await asyncio.sleep(2)
# Launch all bots
await self.launch_all_bots(game_id, models, log_level=log_level)
await self.launch_all_bots(
game_id,
models,
log_level=log_level,
negotiation_rounds=negotiation_rounds,
)
# Monitor the bots
self.monitor_bots()
@ -358,6 +368,7 @@ class MultiBotLauncher:
powers: List[str],
models: Optional[Dict[str, str]] = None,
log_level: str = "INFO",
negotiation_rounds: int = 3,
):
"""
Launch bots to join an existing game.
@ -372,7 +383,13 @@ class MultiBotLauncher:
self.game_id = game_id
# Launch bots for specified powers
await self.launch_all_bots(game_id, models, powers, log_level)
await self.launch_all_bots(
game_id,
models,
powers,
log_level,
negotiation_rounds=negotiation_rounds,
)
# Monitor the bots
self.monitor_bots()
@ -391,18 +408,16 @@ def parse_arguments():
parser.add_argument("--port", type=int, default=8432, help="Server port")
parser.add_argument("--username-base", default="bot", help="Base username for bots")
parser.add_argument("--password", default="password", help="Password for all bots")
parser.add_argument(
"--game-id", help="Game ID to join (creates new if not specified)"
)
parser.add_argument(
"--powers", nargs="+", help="Powers to launch bots for (default: all)"
)
parser.add_argument(
"--models", help="Comma-separated list of models in power order"
)
parser.add_argument("--game-id", help="Game ID to join (creates new if not specified)")
parser.add_argument("--powers", nargs="+", help="Powers to launch bots for (default: all)")
parser.add_argument("--models", help="Comma-separated list of models in power order")
parser.add_argument("--log-level", default="INFO", help="Logging level")
parser.add_argument("--creator-power", default="FRANCE", help="Power that creates the game")
parser.add_argument(
"--creator-power", default="FRANCE", help="Power that creates the game"
"--negotiation-rounds",
type=int,
default=3,
help="Number of negotiation rounds per movement phase (default: 3)",
)
return parser.parse_args()
@ -425,9 +440,7 @@ async def main():
model_list = [m.strip() for m in args.models.split(",")]
powers = args.powers or list(launcher.default_models.keys())
if len(model_list) != len(powers):
logger.error(
f"Number of models ({len(model_list)}) must match number of powers ({len(powers)})"
)
logger.error(f"Number of models ({len(model_list)}) must match number of powers ({len(powers)})")
return
models = dict(zip(powers, model_list))
@ -440,6 +453,7 @@ async def main():
powers=powers,
models=models,
log_level=args.log_level,
negotiation_rounds=args.negotiation_rounds,
)
else:
# Create new game and launch all bots
@ -447,6 +461,7 @@ async def main():
models=models,
log_level=args.log_level,
creator_power=args.creator_power,
negotiation_rounds=args.negotiation_rounds,
)
except KeyboardInterrupt:

View file

@ -29,6 +29,7 @@ include = [".", ".."]
[tool.ruff]
src = ["..", "."]
line-length = 150
[tool.pytest.ini_options]
testpaths = ["bot_client/tests"]
@ -44,8 +45,4 @@ markers = [
]
[dependency-groups]
dev = [
"pytest>=8.4.1",
"pytest-asyncio>=0.21.0",
"websockets>=15.0.1",
]
dev = ["pytest>=8.4.1", "pytest-asyncio>=0.21.0", "websockets>=15.0.1"]

View file

@ -13,25 +13,17 @@ sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
import argparse
import asyncio
import signal
from typing import Optional, Dict
from typing import Optional, Dict, List
import dotenv
from loguru import logger
from websocket_diplomacy_client import (
WebSocketDiplomacyClient,
connect_to_diplomacy_server,
)
from diplomacy.utils.exceptions import DiplomacyException, GameIdException
# Suppress warnings
# os.environ["GRPC_PYTHON_LOG_LEVEL"] = "40"
# os.environ["GRPC_VERBOSITY"] = "ERROR"
# os.environ["ABSL_MIN_LOG_LEVEL"] = "2"
# os.environ["GRPC_POLL_STRATEGY"] = "poll"
from websocket_diplomacy_client import WebSocketDiplomacyClient, connect_to_diplomacy_server
from diplomacy.engine.power import Power
from diplomacy.utils.exceptions import GameIdException
from diplomacy.communication.notifications import GameStatusUpdate
from diplomacy.engine.message import Message
from ai_diplomacy.clients import load_model_client
@ -40,15 +32,16 @@ from ai_diplomacy.game_history import GameHistory
from ai_diplomacy.agent import DiplomacyAgent
from ai_diplomacy.initialization import initialize_agent_state_ext
from config import Configuration
from websocket_negotiations import (
conduct_strategic_negotiation_round,
should_participate_in_negotiations,
get_negotiation_delay,
)
dotenv.load_dotenv()
config = Configuration()
if config.DEBUG:
import tracemalloc
tracemalloc.start()
# TODO: This, but better
config: Configuration = None
class SingleBotPlayer:
@ -70,6 +63,7 @@ class SingleBotPlayer:
hostname: str = "localhost",
port: int = 8432,
game_id: Optional[str] = None,
negotiation_rounds: int = 3,
):
assert username is not None
assert password is not None
@ -87,16 +81,24 @@ class SingleBotPlayer:
# Bot state
self.client: WebSocketDiplomacyClient
self.agent: DiplomacyAgent
self.game_history = GameHistory()
self.game_history: GameHistory = GameHistory()
self.running = True
self.current_phase = None
self.waiting_for_orders = False
self.orders_submitted = False
# Negotiation settings
self.negotiation_rounds = negotiation_rounds
self.current_negotiation_round = 0
self.negotiation_complete = False
# Track error stats
self.error_stats: Dict[str, Dict[str, int]] = {
self.model_name: {"conversation_errors": 0, "order_decoding_errors": 0}
}
self.error_stats: Dict[str, Dict[str, int]] = {self.model_name: {"conversation_errors": 0, "order_decoding_errors": 0}}
# Track messaging patterns for strategic communication
self.message_counts: Dict[str, int] = {} # Messages received from each power
self.response_counts: Dict[str, int] = {} # Responses sent to each power
self.priority_contacts: List[str] = [] # Powers to prioritize for communication
# Setup signal handlers for graceful shutdown
signal.signal(signal.SIGINT, self._signal_handler)
@ -122,9 +124,7 @@ class SingleBotPlayer:
# Join or create game
if self.game_id:
logger.info(f"Joining existing game {self.game_id} as {self.power_name}")
game = await self.client.join_game(
game_id=self.game_id, power_name=self.power_name
)
await self.client.join_game(game_id=self.game_id, power_name=self.power_name)
else:
logger.info(f"Creating new game as {self.power_name}")
await self.client.create_game(
@ -142,9 +142,7 @@ class SingleBotPlayer:
self.agent = DiplomacyAgent(power_name=self.power_name, client=model_client)
# Initialize agent state
await initialize_agent_state_ext(
self.agent, self.client.game, self.game_history, config.log_file_path
)
await initialize_agent_state_ext(self.agent, self.client.game, self.game_history, config.log_file_path)
# Setup game event callbacks
await self._setup_event_callbacks()
@ -198,6 +196,16 @@ class SingleBotPlayer:
self.current_phase = new_phase
self.game_history.add_phase(new_phase)
self.orders_submitted = False
self.current_negotiation_round = 0
self.negotiation_complete = False
# Log message statistics at phase transitions
if hasattr(self, "message_counts") and self.message_counts:
self.log_message_statistics()
# Check if we should start negotiations for movement phases (not adjustment phases)
if new_phase.endswith("M"):
await self._handle_negotiation_phase()
# Check if we need to submit orders for this new phase
await self._check_if_orders_needed()
@ -223,9 +231,7 @@ class SingleBotPlayer:
def _on_message_received(self, game, notification):
"""Handle incoming diplomatic messages."""
message = notification.message
logger.info(
f"Message received from {message.sender} to {message.recipient}: {message.message}"
)
logger.info(f"Message received from {message.sender} to {message.recipient}: {message.message}")
# Add message to game history
self.game_history.add_message(
@ -235,12 +241,17 @@ class SingleBotPlayer:
message_content=message.message,
)
# Track message patterns
if message.recipient == self.power_name and message.sender != self.power_name:
self.message_counts[message.sender] = self.message_counts.get(message.sender, 0) + 1
self._update_priority_contacts()
# If it's a private message to us, consider responding
if message.recipient == self.power_name and message.sender != self.power_name:
# Schedule the async processing in the event loop
asyncio.create_task(self._consider_message_response(message))
def _on_status_update(self, game, notification):
def _on_status_update(self, game, notification: GameStatusUpdate):
"""Handle game status changes."""
logger.info(f"Game status updated: {notification.status}")
@ -261,20 +272,50 @@ class SingleBotPlayer:
# Check if it's a phase where we can submit orders
current_short_phase = self.client.game.current_short_phase
# We submit orders in Movement and Retreat phases
if current_short_phase.endswith("M") or current_short_phase.endswith("R"):
# Check if we have units that can receive orders
orderable_locations = self.client.game.get_orderable_locations(
self.power_name
)
if orderable_locations:
logger.info(f"Orders needed for phase {current_short_phase}")
self.waiting_for_orders = True
await self._submit_orders()
else:
logger.info(
f"No orderable locations for {self.power_name} in {current_short_phase}"
)
logger.debug(f"Checking if orders needed for phase: {current_short_phase}")
# Movement and Retreat phases
orderable_locations = self.client.game.get_orderable_locations(self.power_name)
if orderable_locations:
logger.info(f"Orders needed for {current_short_phase} phase - orderable locations: {orderable_locations}")
self.waiting_for_orders = True
await self._submit_orders()
else:
logger.info(f"No orderable locations for {self.power_name} in {current_short_phase}")
async def _submit_adjustment_orders(self, action_type, count):
"""Submit build or disband orders for adjustment phase."""
# FIXME: This whole function is horse shit.
return
if action_type == "build":
# Get buildable locations for this power
possible_orders = gather_possible_order(self.client.game, self.power_name)
if not possible_orders:
logger.warning(f"No possible build orders for {self.power_name}")
await self.client.set_orders(self.power_name, [])
self.orders_submitted = True
return
# Filter for build orders (usually start with unit type + location)
build_orders = [order for order in possible_orders if " - " not in order and any(order.startswith(unit) for unit in ["A ", "F "])]
logger.info(f"Available build orders for {self.power_name}: {build_orders}")
# Select up to 'count' build orders
selected_orders = build_orders[:count]
logger.info(f"Submitting build orders for {self.power_name}: {selected_orders}")
await self.client.set_orders(self.power_name, selected_orders)
elif action_type == "disband":
# Get current units for disbanding
logger.info(f"Submitting disband orders for {self.power_name}: {disband_orders}")
await self.client.set_orders(self.power_name, disband_orders)
self.orders_submitted = True
self.waiting_for_orders = False
logger.info(f"Adjustment orders submitted successfully for {self.power_name}")
async def _submit_orders(self):
"""Generate and submit orders for the current phase."""
@ -282,108 +323,280 @@ class SingleBotPlayer:
logger.debug("Orders already submitted for this phase")
return
try:
logger.info("Generating orders...")
current_phase = self.client.game.get_current_phase()
logger.info(f"Generating orders for {self.power_name} in phase {current_phase}...")
# Get current board state
board_state = self.client.game.get_state()
# Get current board state
board_state = self.client.game.get_state()
# Get possible orders
possible_orders = gather_possible_orders(self.client.game, self.power_name)
# Get possible orders
possible_orders = gather_possible_orders(self.client.game, self.power_name)
if not possible_orders:
logger.info("No possible orders, submitting empty order set")
await self.client.set_orders(self.power_name, [])
self.orders_submitted = True
return
# Generate orders using AI
orders = await get_valid_orders(
game=self.client.game,
client=self.agent.client,
board_state=board_state,
power_name=self.power_name,
possible_orders=possible_orders,
game_history=self.game_history,
model_error_stats=self.error_stats,
agent_goals=self.agent.goals,
agent_relationships=self.agent.relationships,
agent_private_diary_str=self.agent.format_private_diary_for_prompt(),
phase=self.client.game.get_current_phase(),
)
# Submit orders
if orders:
logger.info(f"Submitting orders: {orders}")
await self.client.set_orders(self.power_name, orders)
# Generate order diary entry
await self.agent.generate_order_diary_entry(
self.client.game,
orders,
config.log_file_path,
)
else:
logger.info("No valid orders generated, submitting empty order set")
await self.client.set_orders(self.power_name, [])
logger.debug(f"Possible orders for {self.power_name}: {possible_orders}")
if not possible_orders:
logger.info(f"No possible orders for {self.power_name}, submitting empty order set")
await self.client.set_orders(self.power_name, [])
self.orders_submitted = True
self.waiting_for_orders = False
logger.info("Orders submitted successfully")
# Call the no wait so we don't sit around for the turns to end.
self.client.game.no_wait()
return
except DiplomacyException as e:
logger.error(f"Error submitting orders: {e}", exc_info=True)
# FIXME: I don't think we want to do this. Likely want to retry again multiple times.
#
# Submit empty orders as fallback
try:
await self.client.set_orders(self.power_name, [])
self.orders_submitted = True
except Exception as fallback_error:
logger.error(f"Failed to submit fallback orders: {fallback_error}")
# Generate orders using AI
orders = await get_valid_orders(
game=self.client.game,
client=self.agent.client,
board_state=board_state,
power_name=self.power_name,
possible_orders=possible_orders,
game_history=self.game_history,
model_error_stats=self.error_stats,
agent_goals=self.agent.goals,
agent_relationships=self.agent.relationships,
agent_private_diary_str=self.agent.format_private_diary_for_prompt(),
phase=self.client.game.get_current_phase(),
)
# Submit orders
if orders:
logger.info(f"Submitting orders: {orders}")
await self.client.set_orders(self.power_name, orders)
# Generate order diary entry
await self.agent.generate_order_diary_entry(
self.client.game,
orders,
config.log_file_path,
)
else:
logger.info("No valid orders generated, submitting empty order set")
await self.client.set_orders(self.power_name, [])
self.orders_submitted = True
self.waiting_for_orders = False
logger.info("Orders submitted successfully")
# Call the no wait so we don't sit around for the turns to end.
# TODO: We probably don't want to call this here.
# We want to call it when negotiations end,
self.client.game.no_wait()
async def _analyze_phase_results(self):
"""Analyze the results of the previous phase."""
try:
logger.info("Analyzing phase results...")
logger.info("Analyzing phase results...")
# Get current board state after processing
board_state = self.client.game.get_state()
# Get current board state after processing
board_state = self.client.game.get_state()
# Generate a simple phase summary
phase_summary = f"Phase {self.current_phase} completed."
# Generate a simple phase summary
phase_summary = f"Phase {self.current_phase} completed."
# Update agent state based on results
await self.agent.analyze_phase_and_update_state(
game=self.client.game,
board_state=board_state,
phase_summary=phase_summary,
# Update agent state based on results
await self.agent.analyze_phase_and_update_state(
game=self.client.game,
board_state=board_state,
phase_summary=phase_summary,
game_history=self.game_history,
log_file_path=config.log_file_path,
)
logger.info("Phase analysis complete")
async def _handle_negotiation_phase(self):
"""Handle the negotiation phase for movement turns."""
# Check if we should participate in negotiations
if not await should_participate_in_negotiations(self.client, self.agent):
logger.info(f"{self.power_name} will not participate in negotiations this phase")
self.negotiation_complete = True
return
logger.info(f"Starting negotiation phase for {self.power_name}")
# TODO: This doesn't need a specific number of negotiation rounds, though it should have a top number of messages this turn so they don't blabber on forever.
# Conduct negotiations for the specified number of rounds
for round_num in range(1, self.negotiation_rounds + 1):
self.current_negotiation_round = round_num
logger.info(f"Negotiation round {round_num}/{self.negotiation_rounds} for {self.power_name}")
# Use strategic negotiation that analyzes recent messages
success = await conduct_strategic_negotiation_round(
client=self.client,
agent=self.agent,
game_history=self.game_history,
model_error_stats=self.error_stats,
log_file_path=config.log_file_path,
round_number=round_num,
max_rounds=self.negotiation_rounds,
)
logger.info("Phase analysis complete")
if not success:
logger.info(f"No messages sent in round {round_num} for {self.power_name}")
except Exception as e:
logger.error(f"Error analyzing phase results: {e}", exc_info=True)
# Wait between rounds to allow other bots to respond
if round_num < self.negotiation_rounds:
delay = get_negotiation_delay(round_num, self.negotiation_rounds)
logger.debug(f"Waiting {delay}s before next negotiation round")
await asyncio.sleep(delay)
self.negotiation_complete = True
logger.info(f"Negotiation phase complete for {self.power_name}")
async def _consider_message_response(self, message: Message):
"""Consider whether to respond to a diplomatic message."""
try:
# Simple logic: if someone greets us, greet back
if any(
word in message.message.lower() for word in ["hello", "hi", "greetings"]
):
response = f"Hello {message.sender}! Good to hear from you."
await self.client.game.send_game_message(
sender=self.power_name, recipient=message.sender, message=response
# Only respond to messages directed at us specifically
if message.recipient != self.power_name:
return
# Don't respond to our own messages
if message.sender == self.power_name:
return
logger.info(f"Considering response to message from {message.sender}: {message.message[:50]}...")
# Enhanced heuristic: respond to direct questions, proposals, and strategic keywords
message_lower = message.message.lower()
strategic_keywords = [
"alliance",
"deal",
"propose",
"agreement",
"support",
"attack",
"coordinate",
"move",
"order",
"help",
"work together",
"partner",
"enemy",
"threat",
"negotiate",
"discuss",
"plan",
"strategy",
"bounce",
"convoy",
"retreat",
]
should_respond = any(
[
"?" in message.message, # Questions
any(word in message_lower for word in ["hello", "hi", "greetings"]), # Greetings
any(keyword in message_lower for keyword in strategic_keywords), # Strategic content
len(message.message.split()) > 15, # Longer messages suggest they want engagement
message.sender in self.priority_contacts, # Priority contacts
]
)
if should_respond:
# Generate a contextual response using AI
# Get current game state for context
board_state = self.client.get_state()
possible_orders = gather_possible_orders(self.client.game, self.power_name)
# Create a simple conversation context
active_powers = [p_name for p_name, p_obj in self.client.powers.items() if not p_obj.is_eliminated()]
# Generate response using the agent's conversation capabilities
responses = await self.agent.client.get_conversation_reply(
game=self.client.game,
board_state=board_state,
power_name=self.power_name,
possible_orders=possible_orders,
game_history=self.game_history,
game_phase=self.client.get_current_short_phase(),
log_file_path=config.log_file_path,
active_powers=active_powers,
agent_goals=self.agent.goals,
agent_relationships=self.agent.relationships,
agent_private_diary_str=self.agent.format_private_diary_for_prompt(),
)
logger.info(f"Sent response to {message.sender}: {response}")
# Send the first response if any were generated
if responses and len(responses) > 0:
response_content = responses[0].get("content", "").strip()
if response_content:
await self.client.send_message(
sender=self.power_name,
recipient=message.sender,
message=response_content,
phase=self.client.get_current_short_phase(),
)
# Add to game history
self.game_history.add_message(
phase_name=self.client.get_current_short_phase(),
sender=self.power_name,
recipient=message.sender,
message_content=response_content,
)
# Track response patterns
self.response_counts[message.sender] = self.response_counts.get(message.sender, 0) + 1
# Add to agent's journal
self.agent.add_journal_entry(
f"Responded to {message.sender} in {self.client.get_current_short_phase()}: {response_content[:100]}..."
)
logger.info(f"Sent AI response to {message.sender}: {response_content[:50]}...")
else:
logger.debug(f"AI generated empty response to {message.sender}")
else:
logger.debug(f"AI generated no responses to {message.sender}")
else:
logger.debug(f"Decided not to respond to message from {message.sender}")
except Exception as e:
logger.error(f"Error responding to message: {e}")
logger.error(f"Error responding to message: {e}", exc_info=True)
def _update_priority_contacts(self) -> None:
"""Update the list of priority contacts based on messaging patterns."""
# Sort powers by message count (descending) and take top 3-4
sorted_contacts = sorted(self.message_counts.items(), key=lambda x: x[1], reverse=True)
# Keep top 4 most active contacts as priority
self.priority_contacts = [contact[0] for contact in sorted_contacts[:4]]
logger.debug(f"Updated priority contacts for {self.power_name}: {self.priority_contacts}")
def get_message_statistics(self) -> Dict[str, any]:
"""Get comprehensive statistics about messaging patterns."""
active_powers = [p_name for p_name, p_obj in self.client.powers.items() if not p_obj.is_eliminated() and p_name != self.power_name]
stats = {
"power_name": self.power_name,
"total_messages_received": sum(self.message_counts.values()),
"total_responses_sent": sum(self.response_counts.values()),
"message_counts_by_power": dict(self.message_counts),
"response_counts_by_power": dict(self.response_counts),
"priority_contacts": list(self.priority_contacts),
"response_rate_by_power": {},
"active_powers": active_powers,
"current_phase": self.current_phase,
}
# Calculate response rates
for power, received in self.message_counts.items():
sent = self.response_counts.get(power, 0)
stats["response_rate_by_power"][power] = sent / received if received > 0 else 0.0
return stats
def log_message_statistics(self) -> None:
"""Log current message statistics for analysis."""
stats = self.get_message_statistics()
logger.info(f"Message Statistics for {self.power_name}:")
logger.info(f" Total messages received: {stats['total_messages_received']}")
logger.info(f" Total responses sent: {stats['total_responses_sent']}")
logger.info(f" Priority contacts: {stats['priority_contacts']}")
for power in stats["active_powers"]:
received = stats["message_counts_by_power"].get(power, 0)
sent = stats["response_counts_by_power"].get(power, 0)
rate = stats["response_rate_by_power"].get(power, 0.0)
logger.info(f" {power}: {received} received, {sent} sent, {rate:.1%} response rate")
async def run(self):
"""Main bot loop."""
@ -408,9 +621,7 @@ class SingleBotPlayer:
else:
logger.info("Bot shutting down")
except GameIdException:
logger.error(
f"Game with id {self.game_id} does not exist on the server. Exiting..."
)
logger.error(f"Game with id {self.game_id} does not exist on the server. Exiting...")
finally:
await self.cleanup()
@ -437,10 +648,14 @@ def parse_arguments():
parser.add_argument("--password", default="password", help="Bot password")
parser.add_argument("--power", default="FRANCE", help="Power to control")
parser.add_argument("--model", default="gpt-3.5-turbo", help="AI model to use")
parser.add_argument(
"--game-id", help="Game ID to join (creates new if not specified)"
)
parser.add_argument("--game-id", help="Game ID to join (creates new if not specified)")
parser.add_argument("--log-level", default="INFO", help="Logging level")
parser.add_argument(
"--negotiation-rounds",
type=int,
default=3,
help="Number of negotiation rounds per movement phase (default: 3)",
)
return parser.parse_args()
@ -459,8 +674,9 @@ async def main():
power_name=args.power,
model_name=args.model,
game_id=args.game_id,
negotiation_rounds=args.negotiation_rounds,
)
config = Configuration(args.power)
await bot.run()

View file

@ -27,9 +27,7 @@ class WebSocketDiplomacyClient:
but communicates with a remote server via WebSocket connections.
"""
def __init__(
self, hostname: str = "localhost", port: int = 8432, use_ssl: bool = False
):
def __init__(self, hostname: str = "localhost", port: int = 8432, use_ssl: bool = False):
"""
Initialize the WebSocket client.
@ -199,9 +197,7 @@ class WebSocketDiplomacyClient:
return await self.channel.get_available_maps()
async def set_orders(
self, power_name: str, orders: List[str], wait: Optional[bool] = None
) -> None:
async def set_orders(self, power_name: str, orders: List[str], wait: Optional[bool] = None) -> None:
"""
Set orders for a power.
@ -246,9 +242,7 @@ class WebSocketDiplomacyClient:
else:
await self.game.no_wait()
async def send_message(
self, sender: str, recipient: str, message: str, phase: Optional[str] = None
) -> None:
async def send_message(self, sender: str, recipient: str, message: str, phase: Optional[str] = None) -> None:
"""
Send a diplomatic message.
@ -269,6 +263,57 @@ class WebSocketDiplomacyClient:
logger.debug(f"Sending message from {sender} to {recipient}: {message}")
await self.game.send_game_message(message=msg)
async def send_broadcast_message(self, sender: str, message: str, phase: Optional[str] = None) -> None:
"""
Send a broadcast message to all active powers.
Args:
sender: Sending power name
message: Message content
phase: Game phase (uses current phase if None)
"""
if not self.game:
raise DiplomacyException("Must join a game first")
if phase is None:
phase = self.game.current_short_phase
# Send to all active powers
active_powers = [power_name for power_name, power in self.powers.items() if not power.is_eliminated() and power_name != sender]
for recipient in active_powers:
await self.send_message(sender, recipient, message, phase)
logger.debug(f"Broadcast message from {sender} to {len(active_powers)} powers: {message}")
async def get_recent_messages(self, phase: Optional[str] = None, limit: int = 50) -> List[Message]:
"""
Get recent messages from the game.
Args:
phase: Filter by specific phase (None for current phase)
limit: Maximum number of messages to return
Returns:
List of recent messages
"""
if not self.game:
raise DiplomacyException("Must join a game first")
if phase is None:
phase = self.game.current_short_phase
# Get messages from the game and filter by phase
all_messages = list(self.messages.values())
# Filter by phase if specified
filtered_messages = [msg for msg in all_messages if msg.phase == phase]
# Sort by time (most recent first) and limit
filtered_messages.sort(key=lambda m: m.time_sent or 0, reverse=True)
return filtered_messages[:limit]
async def process_game(self) -> None:
"""
Force the game to process immediately (admin/moderator only).
@ -289,9 +334,7 @@ class WebSocketDiplomacyClient:
logger.debug("Synchronizing game state")
await self.game.synchronize()
async def get_phase_history(
self, from_phase: Optional[str] = None, to_phase: Optional[str] = None
) -> List[Dict[str, Any]]:
async def get_phase_history(self, from_phase: Optional[str] = None, to_phase: Optional[str] = None) -> List[Dict[str, Any]]:
"""
Get phase history for the game.
@ -305,9 +348,7 @@ class WebSocketDiplomacyClient:
if not self.game:
raise DiplomacyException("Must join a game first")
return await self.game.get_phase_history(
from_phase=from_phase, to_phase=to_phase
)
return await self.game.get_phase_history(from_phase=from_phase, to_phase=to_phase)
async def vote(self, power_name: str, vote: str) -> None:
"""

View file

@ -0,0 +1,326 @@
"""
WebSocket-specific negotiation logic for single bot players.
This module provides negotiation capabilities for individual bots connected
via WebSocket, adapted from the multi-agent negotiation system in the main
ai_diplomacy package.
"""
from typing import Dict, Optional, List
from loguru import logger
from diplomacy.engine.message import GLOBAL
from ai_diplomacy.utils import gather_possible_orders
from ai_diplomacy.agent import DiplomacyAgent
from ai_diplomacy.game_history import GameHistory
from websocket_diplomacy_client import WebSocketDiplomacyClient
async def conduct_single_bot_negotiation(
client: "WebSocketDiplomacyClient",
agent: "DiplomacyAgent",
game_history: "GameHistory",
model_error_stats: Dict[str, Dict[str, int]],
log_file_path: str,
max_rounds: int = 3,
round_number: int = 1,
prioritize_targets: Optional[List[str]] = None,
) -> bool:
"""
Conduct negotiation for a single bot during one negotiation round.
This function handles message generation and sending for one bot during
a negotiation phase. Unlike the multi-agent version, this focuses on
a single power and sends messages via WebSocket.
Args:
client: WebSocket diplomacy client
agent: The bot's AI agent
game_history: Game history tracker
model_error_stats: Error statistics tracking
log_file_path: Path for logging
max_rounds: Maximum number of negotiation rounds
round_number: Current round number (1-indexed)
prioritize_targets: Optional list of powers to prioritize for messaging
Returns:
True if messages were sent successfully, False otherwise
"""
power_name = agent.power_name
logger.info(f"Starting negotiation round {round_number}/{max_rounds} for {power_name}")
# Check if this power is eliminated
if client.get_power(power_name).is_eliminated():
logger.info(f"{power_name} is eliminated, skipping negotiation")
return False
# Check if this power has any orderable locations
possible_orders = gather_possible_orders(client.game, power_name)
if not possible_orders:
logger.info(f"No orderable locations for {power_name}, skipping negotiation")
return False
# Get active powers for context
active_powers = [p_name for p_name, p_obj in client.powers.items() if not p_obj.is_eliminated()]
# Prioritize message targets if specified
message_targets = prioritize_targets if prioritize_targets else active_powers
message_targets = [p for p in message_targets if p in active_powers and p != power_name]
# Generate conversation messages using the AI agent
board_state = client.get_state()
messages = await agent.client.get_conversation_reply(
game=client.game,
board_state=board_state,
power_name=power_name,
possible_orders=possible_orders,
game_history=game_history,
game_phase=client.get_current_short_phase(),
log_file_path=log_file_path,
active_powers=active_powers,
agent_goals=agent.goals,
agent_relationships=agent.relationships,
agent_private_diary_str=agent.format_private_diary_for_prompt(),
)
if not messages:
logger.debug(f"No messages generated for {power_name} in round {round_number}")
return False
# Process and send each message
messages_sent = 0
for message in messages:
success = await _send_negotiation_message(client, agent, game_history, message, power_name)
if success:
messages_sent += 1
logger.info(f"Sent {messages_sent}/{len(messages)} messages for {power_name}")
return messages_sent > 0
async def _send_negotiation_message(
client: WebSocketDiplomacyClient,
agent: DiplomacyAgent,
game_history: GameHistory,
message: Dict,
power_name: str,
) -> bool:
"""
Send a single negotiation message via WebSocket.
Args:
client: WebSocket diplomacy client
agent: The bot's AI agent
game_history: Game history tracker
message: Message dictionary with content and metadata
power_name: Name of the sending power
Returns:
True if message was sent successfully, False otherwise
"""
# Validate message structure
if not isinstance(message, dict) or "content" not in message:
logger.warning(f"Invalid message format from {power_name}: {message}")
return False
content = message.get("content", "").strip()
if not content:
logger.debug(f"Empty message content from {power_name}, skipping")
return False
# Determine recipient
recipient = GLOBAL # Default to global
if message.get("message_type") == "private":
recipient = message.get("recipient", GLOBAL)
# Validate recipient is a valid power
if recipient not in client.powers and recipient != GLOBAL:
logger.warning(f"Invalid recipient '{recipient}' from {power_name}, sending globally")
recipient = GLOBAL
# Send the message via WebSocket
await client.send_message(
sender=power_name,
recipient=recipient,
message=content,
phase=client.get_current_short_phase(),
)
# Add to game history
game_history.add_message(
phase_name=client.get_current_short_phase(),
sender=power_name,
recipient=recipient,
message_content=content,
)
# Add to agent's journal
journal_recipient = f"to {recipient}" if recipient != GLOBAL else "globally"
agent.add_journal_entry(f"Sent message {journal_recipient} in {client.get_current_short_phase()}: {content[:100]}...")
logger.info(f"[{power_name} -> {recipient}] {content[:100]}...")
return True
async def should_participate_in_negotiations(
client: "WebSocketDiplomacyClient",
agent: "DiplomacyAgent",
) -> bool:
"""
Determine if this bot should participate in negotiations.
Args:
client: WebSocket diplomacy client
agent: The bot's AI agent
Returns:
True if the bot should participate in negotiations
"""
power_name = agent.power_name
# Don't negotiate if eliminated
if client.get_power(power_name).is_eliminated():
return False
# Don't negotiate if no orderable locations
possible_orders = gather_possible_orders(client.game, power_name)
if not possible_orders:
return False
# Only negotiate during movement phases
current_phase = client.get_current_short_phase()
if not current_phase.endswith("M"):
return False
return True
def get_negotiation_delay(round_number: int, total_rounds: int) -> float:
"""
Calculate delay between negotiation rounds to allow message processing.
Args:
round_number: Current round number (1-indexed)
total_rounds: Total number of rounds
Returns:
Delay in seconds
"""
# Longer delay in early rounds to allow more strategic messaging
base_delay = 10.0 # Base delay between rounds
if round_number == 1:
return base_delay * 1.5 # Extra time for first round
elif round_number == total_rounds:
return base_delay * 0.5 # Less time for final round
else:
return base_delay
async def analyze_recent_messages_for_targeting(
client: "WebSocketDiplomacyClient",
power_name: str,
max_messages: int = 20,
) -> List[str]:
"""
Analyze recent messages to identify which powers should be prioritized for negotiations.
Args:
client: WebSocket diplomacy client
power_name: Name of the analyzing power
max_messages: Maximum number of recent messages to analyze
Returns:
List of power names in order of priority for messaging
"""
# Get recent messages from current phase
recent_messages = await client.get_recent_messages(limit=max_messages)
# Track who has been active and who has messaged us
message_activity = {}
direct_messages_to_us = {}
for message in recent_messages:
sender = message.sender
recipient = message.recipient
# Track general activity
if sender != power_name:
message_activity[sender] = message_activity.get(sender, 0) + 1
# Track direct messages to us
if recipient == power_name and sender != power_name:
direct_messages_to_us[sender] = direct_messages_to_us.get(sender, 0) + 1
# Get all active powers
active_powers = [p_name for p_name, p_obj in client.powers.items() if not p_obj.is_eliminated() and p_name != power_name]
# Prioritize based on: 1) Powers that messaged us directly, 2) Most active powers
priority_list = []
# First, add powers that sent us direct messages (sorted by count)
direct_senders = sorted(direct_messages_to_us.items(), key=lambda x: x[1], reverse=True)
for sender, _ in direct_senders:
if sender in active_powers:
priority_list.append(sender)
# Then add other active powers (sorted by activity)
remaining_powers = [p for p in active_powers if p not in priority_list]
activity_sorted = sorted(remaining_powers, key=lambda p: message_activity.get(p, 0), reverse=True)
priority_list.extend(activity_sorted)
logger.debug(f"Message targeting priority for {power_name}: {priority_list}")
return priority_list
async def conduct_strategic_negotiation_round(
client: "WebSocketDiplomacyClient",
agent: "DiplomacyAgent",
game_history: "GameHistory",
model_error_stats: Dict[str, Dict[str, int]],
log_file_path: str,
round_number: int,
max_rounds: int = 3,
) -> bool:
"""
Conduct a single negotiation round with strategic message targeting.
This function analyzes recent message activity to determine which powers
to prioritize for messaging in this round.
Args:
client: WebSocket diplomacy client
agent: The bot's AI agent
game_history: Game history tracker
model_error_stats: Error statistics tracking
log_file_path: Path for logging
round_number: Current round number (1-indexed)
max_rounds: Maximum number of negotiation rounds
Returns:
True if messages were sent successfully, False otherwise
"""
power_name = agent.power_name
# Analyze recent messages to prioritize targets
priority_targets = await analyze_recent_messages_for_targeting(client, power_name)
# Limit to top 3-4 targets in later rounds to focus conversations
if round_number > 1:
priority_targets = priority_targets[: min(4, len(priority_targets))]
logger.info(f"Round {round_number} targets for {power_name}: {priority_targets}")
# Conduct negotiation with prioritized targets
return await conduct_single_bot_negotiation(
client=client,
agent=agent,
game_history=game_history,
model_error_stats=model_error_stats,
log_file_path=log_file_path,
max_rounds=max_rounds,
round_number=round_number,
prioritize_targets=priority_targets,
)