mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-19 12:58:09 +00:00
WIP: Exposing errors that cause bot to fail
Working to remove all the lower level try except statements that catch barely, and don't offer any information to top level about what went wrong. Better to throw exceptions and let the consumers catch if needed.
This commit is contained in:
parent
f37c344213
commit
c09d5bfe7e
13 changed files with 1013 additions and 400 deletions
3
.vscode/launch.json
vendored
3
.vscode/launch.json
vendored
|
|
@ -105,7 +105,7 @@
|
|||
"--game-id",
|
||||
"${input:gameId}",
|
||||
"--log-level",
|
||||
"INFO"
|
||||
"DEBUG"
|
||||
],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": true,
|
||||
|
|
@ -138,6 +138,7 @@
|
|||
"type": "pickString",
|
||||
"description": "Select the AI model to use",
|
||||
"options": [
|
||||
"gemini-2.5-flash-lite-preview-06-17",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-4",
|
||||
"claude-3-haiku",
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import re
|
|||
import json_repair
|
||||
import json5
|
||||
|
||||
from ai_diplomacy.game_history import GameHistory
|
||||
from diplomacy.engine.game import Game # More forgiving JSON parser
|
||||
|
||||
# Assuming BaseModelClient is importable from clients.py in the same directory
|
||||
|
|
@ -1150,6 +1151,9 @@ class DiplomacyAgent:
|
|||
game_history: "GameHistory",
|
||||
log_file_path: str | Path,
|
||||
):
|
||||
assert GameHistory is not None, (
|
||||
"GameHistory object must be initialized before calling analyze_phase_and_update_state"
|
||||
)
|
||||
"""Analyzes the outcome of the last phase and updates goals/relationships using the LLM."""
|
||||
# Use self.power_name internally
|
||||
power_name = self.power_name
|
||||
|
|
@ -1169,7 +1173,7 @@ class DiplomacyAgent:
|
|||
return
|
||||
|
||||
# Get previous phase safely from history
|
||||
if not game_history or not game_history.phases:
|
||||
if not game_history.phases:
|
||||
logger.warning(
|
||||
f"[{power_name}] No game history available to analyze for {game.current_short_phase}. Skipping state update."
|
||||
)
|
||||
|
|
@ -1254,9 +1258,8 @@ class DiplomacyAgent:
|
|||
log_entry_success = "FALSE" # Default
|
||||
update_data = None # Initialize
|
||||
|
||||
if (
|
||||
response is not None and response.strip()
|
||||
): # Check if response is not None and not just whitespace
|
||||
# Check if response is not None and not just whitespace
|
||||
if response is not None and response.strip():
|
||||
try:
|
||||
update_data = self._extract_json_from_text(response)
|
||||
logger.debug(
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ from anthropic import AsyncAnthropic
|
|||
|
||||
import google.generativeai as genai
|
||||
|
||||
from diplomacy.engine.message import GLOBAL
|
||||
from .game_history import GameHistory
|
||||
from .utils import (
|
||||
load_prompt,
|
||||
|
|
@ -25,7 +24,6 @@ from .utils import (
|
|||
|
||||
# Import DiplomacyAgent for type hinting if needed, but avoid circular import if possible
|
||||
# from .agent import DiplomacyAgent
|
||||
from .possible_order_context import generate_rich_order_context
|
||||
from .prompt_constructor import (
|
||||
construct_order_generation_prompt,
|
||||
build_context_prompt,
|
||||
|
|
@ -113,7 +111,6 @@ class BaseModelClient:
|
|||
raw_response = await run_llm_and_log(
|
||||
client=self,
|
||||
prompt=prompt,
|
||||
log_file_path=log_file_path,
|
||||
power_name=power_name,
|
||||
phase=phase,
|
||||
response_type="order", # Context for run_llm_and_log's own error logging
|
||||
|
|
@ -553,7 +550,6 @@ class BaseModelClient:
|
|||
raw_response = await run_llm_and_log(
|
||||
client=self,
|
||||
prompt=prompt,
|
||||
log_file_path=log_file_path,
|
||||
power_name=power_name,
|
||||
phase=game_phase, # Use game_phase for logging
|
||||
response_type="plan_reply", # Changed from 'plan' to avoid confusion
|
||||
|
|
@ -604,7 +600,6 @@ class BaseModelClient:
|
|||
raw_response = await run_llm_and_log(
|
||||
client=self,
|
||||
prompt=raw_input_prompt,
|
||||
log_file_path=log_file_path,
|
||||
power_name=power_name,
|
||||
phase=game_phase,
|
||||
response_type="negotiation", # For run_llm_and_log's internal context
|
||||
|
|
@ -829,7 +824,6 @@ class BaseModelClient:
|
|||
raw_plan_response = await run_llm_and_log(
|
||||
client=self, # Pass self (the client instance)
|
||||
prompt=full_prompt,
|
||||
log_file_path=log_file_path,
|
||||
power_name=power_name,
|
||||
phase=game.current_short_phase,
|
||||
response_type="plan_generation", # More specific type for run_llm_and_log context
|
||||
|
|
@ -1328,4 +1322,3 @@ def get_visible_messages_for_power(conversation_messages, power_name):
|
|||
):
|
||||
visible.append(msg)
|
||||
return visible # already in chronological order if appended that way
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import logging
|
|||
import re
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from .utils import run_llm_and_log, log_llm_response, load_prompt
|
||||
from .utils import run_llm_and_log, log_llm_response
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from diplomacy import Game
|
||||
|
|
@ -12,6 +12,25 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> str | None:
|
||||
"""A local copy of the helper from agent.py to avoid circular imports."""
|
||||
import os
|
||||
|
||||
try:
|
||||
if prompts_dir:
|
||||
filepath = os.path.join(prompts_dir, filename)
|
||||
else:
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
default_prompts_dir = os.path.join(current_dir, "prompts")
|
||||
filepath = os.path.join(default_prompts_dir, filename)
|
||||
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading prompt file {filepath}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def run_diary_consolidation(
|
||||
agent: "DiplomacyAgent",
|
||||
game: "Game",
|
||||
|
|
@ -23,19 +42,32 @@ async def run_diary_consolidation(
|
|||
Consolidate older diary entries while keeping recent ones.
|
||||
This is the logic moved from the DiplomacyAgent class.
|
||||
"""
|
||||
logger.info(f"[{agent.power_name}] CONSOLIDATION START — {len(agent.full_private_diary)} total full entries")
|
||||
logger.info(
|
||||
f"[{agent.power_name}] CONSOLIDATION START — "
|
||||
f"{len(agent.full_private_diary)} total full entries"
|
||||
)
|
||||
|
||||
full_entries = [e for e in agent.full_private_diary if not e.startswith("[CONSOLIDATED HISTORY]")]
|
||||
full_entries = [
|
||||
e
|
||||
for e in agent.full_private_diary
|
||||
if not e.startswith("[CONSOLIDATED HISTORY]")
|
||||
]
|
||||
|
||||
if len(full_entries) <= entries_to_keep_unsummarized:
|
||||
agent.private_diary = list(agent.full_private_diary)
|
||||
logger.info(f"[{agent.power_name}] ≤ {entries_to_keep_unsummarized} full entries — skipping consolidation")
|
||||
logger.info(
|
||||
f"[{agent.power_name}] ≤ {entries_to_keep_unsummarized} full entries — "
|
||||
"skipping consolidation"
|
||||
)
|
||||
return
|
||||
|
||||
boundary_entry = full_entries[-entries_to_keep_unsummarized]
|
||||
match = re.search(r"\[[SFWRAB]\s*(\d{4})", boundary_entry)
|
||||
if not match:
|
||||
logger.error(f"[{agent.power_name}] Could not parse year from boundary entry; aborting consolidation")
|
||||
logger.error(
|
||||
f"[{agent.power_name}] Could not parse year from boundary entry; "
|
||||
"aborting consolidation"
|
||||
)
|
||||
agent.private_diary = list(agent.full_private_diary)
|
||||
return
|
||||
|
||||
|
|
@ -46,19 +78,37 @@ async def run_diary_consolidation(
|
|||
m = re.search(r"\[[SFWRAB]\s*(\d{4})", entry)
|
||||
return int(m.group(1)) if m else None
|
||||
|
||||
entries_to_summarize = [e for e in full_entries if (_entry_year(e) is not None and _entry_year(e) < cutoff_year)]
|
||||
entries_to_keep = [e for e in full_entries if (_entry_year(e) is None or _entry_year(e) >= cutoff_year)]
|
||||
entries_to_summarize = [
|
||||
e
|
||||
for e in full_entries
|
||||
if (_entry_year(e) is not None and _entry_year(e) < cutoff_year)
|
||||
]
|
||||
entries_to_keep = [
|
||||
e
|
||||
for e in full_entries
|
||||
if (_entry_year(e) is None or _entry_year(e) >= cutoff_year)
|
||||
]
|
||||
|
||||
logger.info(f"[{agent.power_name}] Summarising {len(entries_to_summarize)} entries; keeping {len(entries_to_keep)} recent entries verbatim")
|
||||
logger.info(
|
||||
f"[{agent.power_name}] Summarising {len(entries_to_summarize)} entries; "
|
||||
f"keeping {len(entries_to_keep)} recent entries verbatim"
|
||||
)
|
||||
|
||||
if not entries_to_summarize:
|
||||
agent.private_diary = list(agent.full_private_diary)
|
||||
logger.warning(f"[{agent.power_name}] No eligible entries to summarise; context diary left unchanged")
|
||||
logger.warning(
|
||||
f"[{agent.power_name}] No eligible entries to summarise; "
|
||||
"context diary left unchanged"
|
||||
)
|
||||
return
|
||||
|
||||
prompt_template = load_prompt("diary_consolidation_prompt.txt", prompts_dir=prompts_dir)
|
||||
prompt_template = _load_prompt_file(
|
||||
"diary_consolidation_prompt.txt", prompts_dir=prompts_dir
|
||||
)
|
||||
if not prompt_template:
|
||||
logger.error(f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting")
|
||||
logger.error(
|
||||
f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting"
|
||||
)
|
||||
return
|
||||
|
||||
prompt = prompt_template.format(
|
||||
|
|
@ -87,14 +137,23 @@ async def run_diary_consolidation(
|
|||
new_summary_entry = f"[CONSOLIDATED HISTORY] {consolidated_text}"
|
||||
agent.private_diary = [new_summary_entry] + entries_to_keep
|
||||
success_flag = "TRUE"
|
||||
logger.info(f"[{agent.power_name}] Consolidation complete — {len(agent.private_diary)} context entries now")
|
||||
logger.info(
|
||||
f"[{agent.power_name}] Consolidation complete — "
|
||||
f"{len(agent.private_diary)} context entries now"
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"[{agent.power_name}] Diary consolidation failed: {exc}", exc_info=True)
|
||||
logger.error(
|
||||
f"[{agent.power_name}] Diary consolidation failed: {exc}", exc_info=True
|
||||
)
|
||||
finally:
|
||||
log_llm_response(
|
||||
log_file_path=log_file_path,
|
||||
model_name=(consolidation_client.model_name if consolidation_client is not None else agent.client.model_name),
|
||||
model_name=(
|
||||
consolidation_client.model_name
|
||||
if consolidation_client is not None
|
||||
else agent.client.model_name
|
||||
),
|
||||
power_name=agent.power_name,
|
||||
phase=game.current_short_phase,
|
||||
response_type="diary_consolidation",
|
||||
|
|
@ -102,3 +161,4 @@ async def run_diary_consolidation(
|
|||
raw_response=raw_response,
|
||||
success=success_flag,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -148,7 +148,6 @@ class GameHistory:
|
|||
return {}
|
||||
return self.phases[-1].plans
|
||||
|
||||
# NEW METHOD
|
||||
def get_messages_this_round(self, power_name: str, current_phase_name: str) -> str:
|
||||
current_phase: Optional[Phase] = None
|
||||
for phase_obj in self.phases:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from os.path import exists
|
||||
import datetime
|
||||
from pydantic_settings import BaseSettings
|
||||
from pathlib import Path
|
||||
import warnings
|
||||
|
|
@ -6,16 +6,16 @@ import warnings
|
|||
|
||||
class Configuration(BaseSettings):
|
||||
DEBUG: bool = False
|
||||
log_file_path: Path = Path("./logs/logs.txt")
|
||||
log_file_path: Path
|
||||
DEEPSEEK_API_KEY: str | None = None
|
||||
OPENAI_API_KEY: str | None = None
|
||||
ANTHROPIC_API_KEY: str | None = None
|
||||
GEMINI_API_KEY: str | None = None
|
||||
OPENROUTER_API_KEY: str | None = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, power_name, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.log_file_path = Path(f"./logs/{datetime.datetime.now().strftime('%d-%m-%y_%H:%M')}/{power_name}.txt")
|
||||
# Make the path absolute, gets rid of weirdness of calling this in different places
|
||||
self.log_file_path = self.log_file_path.resolve()
|
||||
self.log_file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -42,11 +42,7 @@ class Configuration(BaseSettings):
|
|||
"""Override to check for empty API keys at access time"""
|
||||
value = super().__getattribute__(name)
|
||||
|
||||
if name.endswith("_KEY") and (
|
||||
not value or (isinstance(value, str) and len(value) == 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"API key '{name}' is not set or is empty. Please configure it before use."
|
||||
)
|
||||
if name.endswith("_KEY") and (not value or (isinstance(value, str) and len(value) == 0)):
|
||||
raise ValueError(f"API key '{name}' is not set or is empty. Please configure it before use.")
|
||||
|
||||
return value
|
||||
|
|
|
|||
|
|
@ -30,12 +30,9 @@ from websocket_diplomacy_client import (
|
|||
from models import (
|
||||
SignInRequest,
|
||||
CreateGameRequest,
|
||||
JoinGameRequest,
|
||||
SetOrdersRequest,
|
||||
ListGamesRequest,
|
||||
ProcessGameRequest,
|
||||
serialize_message,
|
||||
parse_message,
|
||||
)
|
||||
|
||||
import sys
|
||||
|
|
@ -64,9 +61,7 @@ logging.getLogger("root").setLevel(logging.WARNING)
|
|||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run a Diplomacy game simulation with WebSocket server connection."
|
||||
)
|
||||
parser = argparse.ArgumentParser(description="Run a Diplomacy game simulation with WebSocket server connection.")
|
||||
parser.add_argument(
|
||||
"--hostname",
|
||||
type=str,
|
||||
|
|
@ -137,9 +132,7 @@ def parse_arguments():
|
|||
return parser.parse_args()
|
||||
|
||||
|
||||
async def join_powers_for_testing(
|
||||
client: WebSocketDiplomacyClient, power_model_map: dict
|
||||
):
|
||||
async def join_powers_for_testing(client: WebSocketDiplomacyClient, power_model_map: dict):
|
||||
"""
|
||||
Join multiple powers in the same game for testing purposes.
|
||||
This simulates having multiple AI players in one game.
|
||||
|
|
@ -150,17 +143,13 @@ async def join_powers_for_testing(
|
|||
for power_name in power_names[1:]:
|
||||
try:
|
||||
logger.info(f"Attempting to join power {power_name}")
|
||||
await client.channel.join_game(
|
||||
game_id=client.game_id, power_name=power_name
|
||||
)
|
||||
await client.channel.join_game(game_id=client.game_id, power_name=power_name)
|
||||
logger.info(f"Successfully joined {power_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not join {power_name}: {e}")
|
||||
|
||||
|
||||
async def create_or_join_game(
|
||||
client: WebSocketDiplomacyClient, args, power_model_map: dict
|
||||
):
|
||||
async def create_or_join_game(client: WebSocketDiplomacyClient, args, power_model_map: dict):
|
||||
"""
|
||||
Create a new game or join an existing one based on arguments.
|
||||
"""
|
||||
|
|
@ -171,9 +160,7 @@ async def create_or_join_game(
|
|||
# List available games first to see what's available
|
||||
try:
|
||||
games = await client.list_games()
|
||||
logger.info(
|
||||
f"Available games: {[g.get('game_id', 'unknown') for g in games]}"
|
||||
)
|
||||
logger.info(f"Available games: {[g.get('game_id', 'unknown') for g in games]}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not list games: {e}")
|
||||
|
||||
|
|
@ -189,19 +176,13 @@ async def create_or_join_game(
|
|||
logger.info("Creating new game")
|
||||
|
||||
# Get the first power to control
|
||||
first_power = (
|
||||
list(power_model_map.keys())[0]
|
||||
if not args.create_multi_power_game
|
||||
else None
|
||||
)
|
||||
first_power = list(power_model_map.keys())[0] if not args.create_multi_power_game else None
|
||||
|
||||
game = await client.create_game(
|
||||
map_name="standard",
|
||||
rules=["NO_PRESS", "IGNORE_ERRORS", "POWER_CHOICE"],
|
||||
power_name=first_power,
|
||||
n_controls=7
|
||||
if not args.create_multi_power_game
|
||||
else 1, # Lower requirement for testing
|
||||
n_controls=7 if not args.create_multi_power_game else 1, # Lower requirement for testing
|
||||
deadline=None, # No time pressure for AI testing
|
||||
)
|
||||
|
||||
|
|
@ -215,21 +196,17 @@ async def demonstrate_typed_messages(client: WebSocketDiplomacyClient, args):
|
|||
"""
|
||||
Demonstrate how to use typed WebSocket messages.
|
||||
This shows the proper way to construct messages according to the protocol.
|
||||
|
||||
|
||||
NOTE: This is for demonstration only - the current client implementation
|
||||
uses the diplomacy library's built-in networking which abstracts the raw protocol.
|
||||
"""
|
||||
|
||||
|
||||
logger.info("=== Typed Message Examples ===")
|
||||
|
||||
|
||||
# Example 1: Authentication Request
|
||||
sign_in_request = SignInRequest(
|
||||
request_id=str(uuid.uuid4()),
|
||||
username=args.username,
|
||||
password=args.password
|
||||
)
|
||||
sign_in_request = SignInRequest(request_id=str(uuid.uuid4()), username=args.username, password=args.password)
|
||||
logger.info(f"SignIn request: {serialize_message(sign_in_request)}")
|
||||
|
||||
|
||||
# Example 2: Create Game Request
|
||||
create_game_request = CreateGameRequest(
|
||||
request_id=str(uuid.uuid4()),
|
||||
|
|
@ -237,29 +214,25 @@ async def demonstrate_typed_messages(client: WebSocketDiplomacyClient, args):
|
|||
map_name="standard",
|
||||
rules=["NO_PRESS", "IGNORE_ERRORS", "POWER_CHOICE"],
|
||||
n_controls=1,
|
||||
power_name="FRANCE"
|
||||
power_name="FRANCE",
|
||||
)
|
||||
logger.info(f"CreateGame request: {serialize_message(create_game_request)}")
|
||||
|
||||
|
||||
# Example 3: List Games Request
|
||||
list_games_request = ListGamesRequest(
|
||||
request_id=str(uuid.uuid4()),
|
||||
token="example_auth_token_here",
|
||||
include_protected=False
|
||||
)
|
||||
list_games_request = ListGamesRequest(request_id=str(uuid.uuid4()), token="example_auth_token_here", include_protected=False)
|
||||
logger.info(f"ListGames request: {serialize_message(list_games_request)}")
|
||||
|
||||
|
||||
# Example 4: Set Orders Request (would be used in game loop)
|
||||
if hasattr(client, 'game_id') and hasattr(client, 'token'):
|
||||
if hasattr(client, "game_id") and hasattr(client, "token"):
|
||||
set_orders_request = SetOrdersRequest(
|
||||
request_id=str(uuid.uuid4()),
|
||||
token=client.token or "example_token",
|
||||
game_id=client.game_id or "example_game_id",
|
||||
game_role="FRANCE",
|
||||
orders=["A PAR H", "F BRE H", "A MAR H"]
|
||||
orders=["A PAR H", "F BRE H", "A MAR H"],
|
||||
)
|
||||
logger.info(f"SetOrders request: {serialize_message(set_orders_request)}")
|
||||
|
||||
|
||||
logger.info("=== End Typed Message Examples ===")
|
||||
|
||||
|
||||
|
|
@ -282,9 +255,7 @@ async def main():
|
|||
logger.info("Starting WebSocket-based Diplomacy game with multiple LLMs")
|
||||
start_whole = time.time()
|
||||
|
||||
model_error_stats = defaultdict(
|
||||
lambda: {"conversation_errors": 0, "order_decoding_errors": 0}
|
||||
)
|
||||
model_error_stats = defaultdict(lambda: {"conversation_errors": 0, "order_decoding_errors": 0})
|
||||
|
||||
# Determine the result folder based on a timestamp
|
||||
timestamp_str = time.strftime("%Y%m%d_%H%M%S")
|
||||
|
|
@ -306,9 +277,7 @@ async def main():
|
|||
|
||||
# File paths
|
||||
manifesto_path = f"{result_folder}/game_manifesto.txt"
|
||||
game_file_path = (
|
||||
args.output if args.output else f"{result_folder}/lmvsgame_websocket.json"
|
||||
)
|
||||
game_file_path = args.output if args.output else f"{result_folder}/lmvsgame_websocket.json"
|
||||
overview_file_path = f"{result_folder}/overview.jsonl"
|
||||
llm_log_file_path = f"{result_folder}/llm_responses.csv"
|
||||
|
||||
|
|
@ -325,9 +294,7 @@ async def main():
|
|||
]
|
||||
provided_models = [name.strip() for name in args.models.split(",")]
|
||||
if len(provided_models) != len(powers_order):
|
||||
logger.error(
|
||||
f"Expected {len(powers_order)} models for --models but got {len(provided_models)}. Exiting."
|
||||
)
|
||||
logger.error(f"Expected {len(powers_order)} models for --models but got {len(provided_models)}. Exiting.")
|
||||
return
|
||||
power_model_map = dict(zip(powers_order, provided_models))
|
||||
else:
|
||||
|
|
@ -377,32 +344,20 @@ async def main():
|
|||
client_obj = load_model_client(model_id)
|
||||
agent = DiplomacyAgent(power_name=power_name, client=client_obj)
|
||||
agents[power_name] = agent
|
||||
logger.info(
|
||||
f"Preparing initialization task for {power_name} with model {model_id}"
|
||||
)
|
||||
initialization_tasks.append(
|
||||
initialize_agent_state_ext(
|
||||
agent, client.game, game_history, llm_log_file_path
|
||||
)
|
||||
)
|
||||
logger.info(f"Preparing initialization task for {power_name} with model {model_id}")
|
||||
initialization_tasks.append(initialize_agent_state_ext(agent, client.game, game_history, llm_log_file_path))
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to create agent for {power_name} with model {model_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Skipping agent initialization for {power_name} (no model or eliminated)"
|
||||
)
|
||||
logger.info(f"Skipping agent initialization for {power_name} (no model or eliminated)")
|
||||
|
||||
# Run initializations concurrently
|
||||
if initialization_tasks:
|
||||
logger.info(
|
||||
f"Running {len(initialization_tasks)} agent initializations concurrently..."
|
||||
)
|
||||
initialization_results = await asyncio.gather(
|
||||
*initialization_tasks, return_exceptions=True
|
||||
)
|
||||
logger.info(f"Running {len(initialization_tasks)} agent initializations concurrently...")
|
||||
initialization_results = await asyncio.gather(*initialization_tasks, return_exceptions=True)
|
||||
|
||||
initialized_powers = list(agents.keys())
|
||||
for i, result in enumerate(initialization_results):
|
||||
|
|
@ -414,9 +369,7 @@ async def main():
|
|||
exc_info=result,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Successfully initialized agent state for {power_name}."
|
||||
)
|
||||
logger.info(f"Successfully initialized agent state for {power_name}.")
|
||||
|
||||
# Main game loop
|
||||
all_phase_relationships = {}
|
||||
|
|
@ -433,9 +386,7 @@ async def main():
|
|||
game_history.add_phase(current_phase)
|
||||
current_short_phase = client.get_current_short_phase()
|
||||
|
||||
logger.info(
|
||||
f"PHASE: {current_phase} (time so far: {phase_start - start_whole:.2f}s)"
|
||||
)
|
||||
logger.info(f"PHASE: {current_phase} (time so far: {phase_start - start_whole:.2f}s)")
|
||||
|
||||
# Prevent unbounded simulation based on year
|
||||
year_str = current_phase[1:5]
|
||||
|
|
@ -447,9 +398,7 @@ async def main():
|
|||
# Negotiations for movement phases
|
||||
if client.get_current_short_phase().endswith("M"):
|
||||
if args.num_negotiation_rounds > 0:
|
||||
logger.info(
|
||||
f"Running {args.num_negotiation_rounds} rounds of negotiations..."
|
||||
)
|
||||
logger.info(f"Running {args.num_negotiation_rounds} rounds of negotiations...")
|
||||
game_history = await conduct_negotiations(
|
||||
client.game, # Pass the NetworkGame object
|
||||
agents,
|
||||
|
|
@ -459,9 +408,7 @@ async def main():
|
|||
log_file_path=llm_log_file_path,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Skipping negotiation phase as num_negotiation_rounds=0"
|
||||
)
|
||||
logger.info("Skipping negotiation phase as num_negotiation_rounds=0")
|
||||
|
||||
# Planning phase (if enabled)
|
||||
if args.planning_phase:
|
||||
|
|
@ -475,29 +422,19 @@ async def main():
|
|||
)
|
||||
|
||||
# Generate negotiation diary entries
|
||||
logger.info(
|
||||
f"Generating negotiation diary entries for phase {current_short_phase}..."
|
||||
)
|
||||
active_powers_for_neg_diary = [
|
||||
p for p in agents.keys() if not client.get_power(p).is_eliminated()
|
||||
]
|
||||
logger.info(f"Generating negotiation diary entries for phase {current_short_phase}...")
|
||||
active_powers_for_neg_diary = [p for p in agents.keys() if not client.get_power(p).is_eliminated()]
|
||||
|
||||
neg_diary_tasks = []
|
||||
for power_name, agent in agents.items():
|
||||
if not client.get_power(power_name).is_eliminated():
|
||||
neg_diary_tasks.append(
|
||||
agent.generate_negotiation_diary_entry(
|
||||
client.game, game_history, llm_log_file_path
|
||||
)
|
||||
)
|
||||
neg_diary_tasks.append(agent.generate_negotiation_diary_entry(client.game, game_history, llm_log_file_path))
|
||||
if neg_diary_tasks:
|
||||
await asyncio.gather(*neg_diary_tasks, return_exceptions=True)
|
||||
|
||||
# AI Decision Making: Get orders for each controlled power
|
||||
logger.info("Getting orders from agents...")
|
||||
active_powers_for_orders = [
|
||||
p for p in agents.keys() if not client.get_power(p).is_eliminated()
|
||||
]
|
||||
active_powers_for_orders = [p for p in agents.keys() if not client.get_power(p).is_eliminated()]
|
||||
|
||||
order_tasks = []
|
||||
order_power_names = []
|
||||
|
|
@ -505,35 +442,23 @@ async def main():
|
|||
|
||||
for power_name, agent in agents.items():
|
||||
if client.get_power(power_name).is_eliminated():
|
||||
logger.debug(
|
||||
f"Skipping order generation for eliminated power {power_name}."
|
||||
)
|
||||
logger.debug(f"Skipping order generation for eliminated power {power_name}.")
|
||||
continue
|
||||
|
||||
# Diagnostic logging
|
||||
logger.info(
|
||||
f"--- Diagnostic Log for {power_name} in phase {current_phase} ---"
|
||||
)
|
||||
logger.info(f"--- Diagnostic Log for {power_name} in phase {current_phase} ---")
|
||||
try:
|
||||
orderable_locs = client.get_orderable_locations(power_name)
|
||||
logger.info(
|
||||
f"[{power_name}][{current_phase}] Orderable locations: {orderable_locs}"
|
||||
)
|
||||
logger.info(f"[{power_name}][{current_phase}] Orderable locations: {orderable_locs}")
|
||||
actual_units = client.get_units(power_name)
|
||||
logger.info(
|
||||
f"[{power_name}][{current_phase}] Actual units: {actual_units}"
|
||||
)
|
||||
logger.info(f"[{power_name}][{current_phase}] Actual units: {actual_units}")
|
||||
except Exception as e_diag:
|
||||
logger.error(
|
||||
f"[{power_name}][{current_phase}] Error during diagnostic logging: {e_diag}"
|
||||
)
|
||||
logger.error(f"[{power_name}][{current_phase}] Error during diagnostic logging: {e_diag}")
|
||||
|
||||
# Calculate possible orders
|
||||
possible_orders = gather_possible_orders(client.game, power_name)
|
||||
if not possible_orders:
|
||||
logger.debug(
|
||||
f"No orderable locations for {power_name}; submitting empty orders."
|
||||
)
|
||||
logger.debug(f"No orderable locations for {power_name}; submitting empty orders.")
|
||||
await client.set_orders(power_name, [])
|
||||
continue
|
||||
|
||||
|
|
@ -559,12 +484,8 @@ async def main():
|
|||
|
||||
# Run order generation concurrently
|
||||
if order_tasks:
|
||||
logger.debug(
|
||||
f"Running {len(order_tasks)} order generation tasks concurrently..."
|
||||
)
|
||||
order_results = await asyncio.gather(
|
||||
*order_tasks, return_exceptions=True
|
||||
)
|
||||
logger.debug(f"Running {len(order_tasks)} order generation tasks concurrently...")
|
||||
order_results = await asyncio.gather(*order_tasks, return_exceptions=True)
|
||||
else:
|
||||
order_results = []
|
||||
|
||||
|
|
@ -580,9 +501,7 @@ async def main():
|
|||
)
|
||||
await client.set_orders(p_name, [])
|
||||
elif result is None:
|
||||
logger.warning(
|
||||
f"get_valid_orders returned None for {p_name}. Setting empty orders."
|
||||
)
|
||||
logger.warning(f"get_valid_orders returned None for {p_name}. Setting empty orders.")
|
||||
await client.set_orders(p_name, [])
|
||||
else:
|
||||
orders = result
|
||||
|
|
@ -598,17 +517,13 @@ async def main():
|
|||
# orders=orders
|
||||
# )
|
||||
# await client.send_typed_message(set_orders_msg)
|
||||
|
||||
|
||||
await client.set_orders(p_name, orders)
|
||||
logger.debug(
|
||||
f"Set orders for {p_name} in {current_short_phase}: {orders}"
|
||||
)
|
||||
logger.debug(f"Set orders for {p_name} in {current_short_phase}: {orders}")
|
||||
|
||||
# Generate order diary entry
|
||||
try:
|
||||
await agent.generate_order_diary_entry(
|
||||
client.game, orders, llm_log_file_path
|
||||
)
|
||||
await agent.generate_order_diary_entry(client.game, orders, llm_log_file_path)
|
||||
except Exception as e_diary:
|
||||
logger.error(
|
||||
f"Error generating order diary for {p_name}: {e_diary}",
|
||||
|
|
@ -627,7 +542,7 @@ async def main():
|
|||
# phase=current_phase
|
||||
# )
|
||||
# await client.send_typed_message(process_msg)
|
||||
|
||||
|
||||
logger.info(f"Processing orders for {current_phase}...")
|
||||
await simulate_game_processing(client)
|
||||
|
||||
|
|
@ -650,22 +565,13 @@ async def main():
|
|||
# Collect relationships for this phase
|
||||
current_relationships_for_phase = {}
|
||||
for power_name, agent in agents.items():
|
||||
if (
|
||||
power_name in client.powers
|
||||
and not client.get_power(power_name).is_eliminated()
|
||||
):
|
||||
if power_name in client.powers and not client.get_power(power_name).is_eliminated():
|
||||
current_relationships_for_phase[power_name] = agent.relationships
|
||||
all_phase_relationships[current_short_phase] = (
|
||||
current_relationships_for_phase
|
||||
)
|
||||
all_phase_relationships[current_short_phase] = current_relationships_for_phase
|
||||
|
||||
# Generate phase result diary entries
|
||||
logger.info(
|
||||
f"Generating phase result diary entries for completed phase {current_phase}..."
|
||||
)
|
||||
phase_summary = getattr(client.game, "phase_summaries", {}).get(
|
||||
current_phase, "(Summary not generated)"
|
||||
)
|
||||
logger.info(f"Generating phase result diary entries for completed phase {current_phase}...")
|
||||
phase_summary = getattr(client.game, "phase_summaries", {}).get(current_phase, "(Summary not generated)")
|
||||
all_orders_this_phase = current_order_history
|
||||
|
||||
phase_result_diary_tasks = []
|
||||
|
|
@ -685,16 +591,10 @@ async def main():
|
|||
await asyncio.gather(*phase_result_diary_tasks, return_exceptions=True)
|
||||
|
||||
# State update analysis
|
||||
logger.info(
|
||||
f"Starting state update analysis for completed phase {current_phase}..."
|
||||
)
|
||||
logger.info(f"Starting state update analysis for completed phase {current_phase}...")
|
||||
current_board_state = client.get_state()
|
||||
|
||||
active_agent_powers = [
|
||||
(p, power)
|
||||
for p, power in client.powers.items()
|
||||
if p in agents and not power.is_eliminated()
|
||||
]
|
||||
active_agent_powers = [(p, power) for p, power in client.powers.items() if p in agents and not power.is_eliminated()]
|
||||
|
||||
if active_agent_powers:
|
||||
state_update_tasks = []
|
||||
|
|
@ -781,4 +681,3 @@ async def main():
|
|||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
|
|
|
|||
|
|
@ -8,34 +8,39 @@ and notifications (server -> client).
|
|||
|
||||
from typing import Optional, List, Dict, Any, Union, Literal
|
||||
from pydantic import BaseModel, Field
|
||||
from abc import ABC, abstractmethod
|
||||
from abc import ABC
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Base Message Classes
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class BaseMessage(BaseModel, ABC):
|
||||
"""Base class for all WebSocket messages."""
|
||||
|
||||
name: str
|
||||
|
||||
|
||||
class Config:
|
||||
extra = "forbid"
|
||||
|
||||
|
||||
class BaseRequest(BaseMessage):
|
||||
"""Base class for all client -> server requests."""
|
||||
|
||||
request_id: str
|
||||
re_sent: bool = False
|
||||
|
||||
|
||||
class BaseResponse(BaseMessage):
|
||||
"""Base class for all server -> client responses."""
|
||||
|
||||
request_id: str
|
||||
|
||||
|
||||
class BaseNotification(BaseMessage):
|
||||
"""Base class for all server -> client notifications (no request_id)."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
|
|
@ -43,8 +48,10 @@ class BaseNotification(BaseMessage):
|
|||
# Authentication & Connection Level Messages
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class SignInRequest(BaseRequest):
|
||||
"""Client authentication request."""
|
||||
|
||||
name: Literal["sign_in"] = "sign_in"
|
||||
username: str
|
||||
password: str
|
||||
|
|
@ -52,6 +59,7 @@ class SignInRequest(BaseRequest):
|
|||
|
||||
class GetDaidePortRequest(BaseRequest):
|
||||
"""Request DAIDE TCP port for a game."""
|
||||
|
||||
name: Literal["get_daide_port"] = "get_daide_port"
|
||||
game_id: str
|
||||
|
||||
|
|
@ -60,13 +68,16 @@ class GetDaidePortRequest(BaseRequest):
|
|||
# Channel Level Messages (require authentication token)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ChannelRequest(BaseRequest):
|
||||
"""Base class for channel-level requests that require authentication."""
|
||||
|
||||
token: str
|
||||
|
||||
|
||||
class CreateGameRequest(ChannelRequest):
|
||||
"""Create a new game."""
|
||||
|
||||
name: Literal["create_game"] = "create_game"
|
||||
map_name: str = "standard"
|
||||
rules: List[str] = Field(default_factory=lambda: ["NO_PRESS", "IGNORE_ERRORS"])
|
||||
|
|
@ -78,6 +89,7 @@ class CreateGameRequest(ChannelRequest):
|
|||
|
||||
class JoinGameRequest(ChannelRequest):
|
||||
"""Join an existing game."""
|
||||
|
||||
name: Literal["join_game"] = "join_game"
|
||||
game_id: str
|
||||
power_name: Optional[str] = None
|
||||
|
|
@ -86,6 +98,7 @@ class JoinGameRequest(ChannelRequest):
|
|||
|
||||
class JoinPowersRequest(ChannelRequest):
|
||||
"""Join multiple powers in a game."""
|
||||
|
||||
name: Literal["join_powers"] = "join_powers"
|
||||
game_id: str
|
||||
power_names: List[str]
|
||||
|
|
@ -94,32 +107,37 @@ class JoinPowersRequest(ChannelRequest):
|
|||
|
||||
class ListGamesRequest(ChannelRequest):
|
||||
"""List available games."""
|
||||
|
||||
name: Literal["list_games"] = "list_games"
|
||||
game_id_filter: Optional[str] = None
|
||||
map_name: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
include_protected: bool = False
|
||||
|
||||
|
||||
class GetPlayablePowersRequest(ChannelRequest):
|
||||
"""Get uncontrolled powers in a game."""
|
||||
|
||||
name: Literal["get_playable_powers"] = "get_playable_powers"
|
||||
game_id: str
|
||||
|
||||
|
||||
class GetAvailableMapsRequest(ChannelRequest):
|
||||
"""Get list of available maps."""
|
||||
|
||||
name: Literal["get_available_maps"] = "get_available_maps"
|
||||
|
||||
|
||||
class GetDummyWaitingPowersRequest(ChannelRequest):
|
||||
"""Get AI-controllable powers (bot use)."""
|
||||
|
||||
name: Literal["get_dummy_waiting_powers"] = "get_dummy_waiting_powers"
|
||||
game_id: str
|
||||
|
||||
|
||||
class SetGradeRequest(ChannelRequest):
|
||||
"""Modify user permissions."""
|
||||
|
||||
name: Literal["set_grade"] = "set_grade"
|
||||
username: str
|
||||
grade: str
|
||||
|
|
@ -127,11 +145,13 @@ class SetGradeRequest(ChannelRequest):
|
|||
|
||||
class DeleteAccountRequest(ChannelRequest):
|
||||
"""Delete user account."""
|
||||
|
||||
name: Literal["delete_account"] = "delete_account"
|
||||
|
||||
|
||||
class LogoutRequest(ChannelRequest):
|
||||
"""Disconnect from server."""
|
||||
|
||||
name: Literal["logout"] = "logout"
|
||||
|
||||
|
||||
|
|
@ -139,8 +159,10 @@ class LogoutRequest(ChannelRequest):
|
|||
# Game Level Messages (require authentication + game context)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class GameRequest(ChannelRequest):
|
||||
"""Base class for game-level requests."""
|
||||
|
||||
game_id: str
|
||||
game_role: str # Power name like "ENGLAND"
|
||||
phase: Optional[str] = None
|
||||
|
|
@ -148,18 +170,21 @@ class GameRequest(ChannelRequest):
|
|||
|
||||
class SetOrdersRequest(GameRequest):
|
||||
"""Submit orders for a power."""
|
||||
|
||||
name: Literal["set_orders"] = "set_orders"
|
||||
orders: List[str]
|
||||
|
||||
|
||||
class SetWaitFlagRequest(GameRequest):
|
||||
"""Set wait flag for turn processing."""
|
||||
|
||||
name: Literal["set_wait_flag"] = "set_wait_flag"
|
||||
wait: bool
|
||||
|
||||
|
||||
class SendGameMessageRequest(GameRequest):
|
||||
"""Send diplomatic message."""
|
||||
|
||||
name: Literal["send_game_message"] = "send_game_message"
|
||||
recipient: str # Power name or "GLOBAL"
|
||||
message: str
|
||||
|
|
@ -168,11 +193,13 @@ class SendGameMessageRequest(GameRequest):
|
|||
|
||||
class GetAllPossibleOrdersRequest(GameRequest):
|
||||
"""Get legal orders for current phase."""
|
||||
|
||||
name: Literal["get_all_possible_orders"] = "get_all_possible_orders"
|
||||
|
||||
|
||||
class GetPhaseHistoryRequest(GameRequest):
|
||||
"""Get historical game phases."""
|
||||
|
||||
name: Literal["get_phase_history"] = "get_phase_history"
|
||||
from_phase: Optional[str] = None
|
||||
to_phase: Optional[str] = None
|
||||
|
|
@ -180,45 +207,53 @@ class GetPhaseHistoryRequest(GameRequest):
|
|||
|
||||
class ProcessGameRequest(GameRequest):
|
||||
"""Force game processing (master only)."""
|
||||
|
||||
name: Literal["process_game"] = "process_game"
|
||||
|
||||
|
||||
class VoteRequest(GameRequest):
|
||||
"""Vote for/against draw."""
|
||||
|
||||
name: Literal["vote"] = "vote"
|
||||
vote: Literal["yes", "no"]
|
||||
|
||||
|
||||
class SaveGameRequest(GameRequest):
|
||||
"""Export game as JSON."""
|
||||
|
||||
name: Literal["save_game"] = "save_game"
|
||||
|
||||
|
||||
class SetGameStateRequest(GameRequest):
|
||||
"""Modify game state (master only)."""
|
||||
|
||||
name: Literal["set_game_state"] = "set_game_state"
|
||||
state: Dict[str, Any]
|
||||
|
||||
|
||||
class SetGameStatusRequest(GameRequest):
|
||||
"""Change game status (master only)."""
|
||||
|
||||
name: Literal["set_game_status"] = "set_game_status"
|
||||
status: str
|
||||
|
||||
|
||||
class SetDummyPowersRequest(GameRequest):
|
||||
"""Make powers AI-controlled (master only)."""
|
||||
|
||||
name: Literal["set_dummy_powers"] = "set_dummy_powers"
|
||||
power_names: List[str]
|
||||
|
||||
|
||||
class DeleteGameRequest(GameRequest):
|
||||
"""Delete game (master only)."""
|
||||
|
||||
name: Literal["delete_game"] = "delete_game"
|
||||
|
||||
|
||||
class LeaveGameRequest(GameRequest):
|
||||
"""Leave game."""
|
||||
|
||||
name: Literal["leave_game"] = "leave_game"
|
||||
|
||||
|
||||
|
|
@ -226,13 +261,16 @@ class LeaveGameRequest(GameRequest):
|
|||
# Response Messages (Server -> Client)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class OkResponse(BaseResponse):
|
||||
"""Generic success response."""
|
||||
|
||||
name: Literal["ok"] = "ok"
|
||||
|
||||
|
||||
class ErrorResponse(BaseResponse):
|
||||
"""Error response with error type and message."""
|
||||
|
||||
name: Literal["error"] = "error"
|
||||
error_type: str
|
||||
message: str
|
||||
|
|
@ -240,60 +278,70 @@ class ErrorResponse(BaseResponse):
|
|||
|
||||
class DataTokenResponse(BaseResponse):
|
||||
"""Contains authentication token."""
|
||||
|
||||
name: Literal["data_token"] = "data_token"
|
||||
data: str # The authentication token
|
||||
|
||||
|
||||
class DataGameResponse(BaseResponse):
|
||||
"""Contains full game object."""
|
||||
|
||||
name: Literal["data_game"] = "data_game"
|
||||
data: Dict[str, Any] # The complete game state
|
||||
|
||||
|
||||
class DataGameInfoResponse(BaseResponse):
|
||||
"""Contains game metadata."""
|
||||
|
||||
name: Literal["data_game_info"] = "data_game_info"
|
||||
data: Dict[str, Any]
|
||||
|
||||
|
||||
class DataGamesResponse(BaseResponse):
|
||||
"""List of game information."""
|
||||
|
||||
name: Literal["data_games"] = "data_games"
|
||||
data: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class DataMapsResponse(BaseResponse):
|
||||
"""Available maps information."""
|
||||
|
||||
name: Literal["data_maps"] = "data_maps"
|
||||
data: List[str]
|
||||
|
||||
|
||||
class DataPowerNamesResponse(BaseResponse):
|
||||
"""List of power names."""
|
||||
|
||||
name: Literal["data_power_names"] = "data_power_names"
|
||||
data: List[str]
|
||||
|
||||
|
||||
class DataPossibleOrdersResponse(BaseResponse):
|
||||
"""Legal orders and locations."""
|
||||
|
||||
name: Literal["data_possible_orders"] = "data_possible_orders"
|
||||
data: Dict[str, List[str]] # Location -> list of possible orders
|
||||
|
||||
|
||||
class DataGamePhasesResponse(BaseResponse):
|
||||
"""Historical game phases."""
|
||||
|
||||
name: Literal["data_game_phases"] = "data_game_phases"
|
||||
data: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class DataSavedGameResponse(BaseResponse):
|
||||
"""Exported game JSON."""
|
||||
|
||||
name: Literal["data_saved_game"] = "data_saved_game"
|
||||
data: Dict[str, Any]
|
||||
|
||||
|
||||
class DataPortResponse(BaseResponse):
|
||||
"""DAIDE port number."""
|
||||
|
||||
name: Literal["data_port"] = "data_port"
|
||||
data: int
|
||||
|
||||
|
|
@ -302,8 +350,10 @@ class DataPortResponse(BaseResponse):
|
|||
# Notification Messages (Server -> Client)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class GameProcessedNotification(BaseNotification):
|
||||
"""Phase completed, new orders phase."""
|
||||
|
||||
name: Literal["game_processed"] = "game_processed"
|
||||
game_id: str
|
||||
phase: str
|
||||
|
|
@ -312,6 +362,7 @@ class GameProcessedNotification(BaseNotification):
|
|||
|
||||
class GamePhaseUpdateNotification(BaseNotification):
|
||||
"""Game state changed."""
|
||||
|
||||
name: Literal["game_phase_update"] = "game_phase_update"
|
||||
game_id: str
|
||||
phase: str
|
||||
|
|
@ -320,6 +371,7 @@ class GamePhaseUpdateNotification(BaseNotification):
|
|||
|
||||
class GameStatusUpdateNotification(BaseNotification):
|
||||
"""Game status changed (forming/active/paused/completed)."""
|
||||
|
||||
name: Literal["game_status_update"] = "game_status_update"
|
||||
game_id: str
|
||||
status: str
|
||||
|
|
@ -327,6 +379,7 @@ class GameStatusUpdateNotification(BaseNotification):
|
|||
|
||||
class PowersControllersNotification(BaseNotification):
|
||||
"""Power control assignments changed."""
|
||||
|
||||
name: Literal["powers_controllers"] = "powers_controllers"
|
||||
game_id: str
|
||||
controllers: Dict[str, str] # Power -> Controller mapping
|
||||
|
|
@ -334,6 +387,7 @@ class PowersControllersNotification(BaseNotification):
|
|||
|
||||
class PowerOrdersUpdateNotification(BaseNotification):
|
||||
"""Player submitted new orders."""
|
||||
|
||||
name: Literal["power_orders_update"] = "power_orders_update"
|
||||
game_id: str
|
||||
power_name: str
|
||||
|
|
@ -343,15 +397,17 @@ class PowerOrdersUpdateNotification(BaseNotification):
|
|||
|
||||
class PowerOrdersFlagNotification(BaseNotification):
|
||||
"""Player order submission status."""
|
||||
|
||||
name: Literal["power_orders_flag"] = "power_orders_flag"
|
||||
game_id: str
|
||||
power_name: str
|
||||
power_name: str
|
||||
order_is_set: bool
|
||||
phase: str
|
||||
|
||||
|
||||
class PowerWaitFlagNotification(BaseNotification):
|
||||
"""Player wait flag changed."""
|
||||
|
||||
name: Literal["power_wait_flag"] = "power_wait_flag"
|
||||
game_id: str
|
||||
power_name: str
|
||||
|
|
@ -360,6 +416,7 @@ class PowerWaitFlagNotification(BaseNotification):
|
|||
|
||||
class GameMessageReceivedNotification(BaseNotification):
|
||||
"""Diplomatic message received."""
|
||||
|
||||
name: Literal["game_message_received"] = "game_message_received"
|
||||
game_id: str
|
||||
sender: str
|
||||
|
|
@ -371,6 +428,7 @@ class GameMessageReceivedNotification(BaseNotification):
|
|||
|
||||
class VoteUpdatedNotification(BaseNotification):
|
||||
"""Draw votes changed (omniscient view)."""
|
||||
|
||||
name: Literal["vote_updated"] = "vote_updated"
|
||||
game_id: str
|
||||
votes: Dict[str, str] # Power -> vote mapping
|
||||
|
|
@ -378,6 +436,7 @@ class VoteUpdatedNotification(BaseNotification):
|
|||
|
||||
class VoteCountUpdatedNotification(BaseNotification):
|
||||
"""Vote count changed (observer view)."""
|
||||
|
||||
name: Literal["vote_count_updated"] = "vote_count_updated"
|
||||
game_id: str
|
||||
count_yes: int
|
||||
|
|
@ -386,6 +445,7 @@ class VoteCountUpdatedNotification(BaseNotification):
|
|||
|
||||
class PowerVoteUpdatedNotification(BaseNotification):
|
||||
"""Own power's vote changed."""
|
||||
|
||||
name: Literal["power_vote_updated"] = "power_vote_updated"
|
||||
game_id: str
|
||||
power_name: str
|
||||
|
|
@ -394,12 +454,14 @@ class PowerVoteUpdatedNotification(BaseNotification):
|
|||
|
||||
class GameDeletedNotification(BaseNotification):
|
||||
"""Game removed from server."""
|
||||
|
||||
name: Literal["game_deleted"] = "game_deleted"
|
||||
game_id: str
|
||||
|
||||
|
||||
class OmniscientUpdatedNotification(BaseNotification):
|
||||
"""Observer permissions changed."""
|
||||
|
||||
name: Literal["omniscient_updated"] = "omniscient_updated"
|
||||
game_id: str
|
||||
omniscient_type: str
|
||||
|
|
@ -407,12 +469,14 @@ class OmniscientUpdatedNotification(BaseNotification):
|
|||
|
||||
class AccountDeletedNotification(BaseNotification):
|
||||
"""User account deleted."""
|
||||
|
||||
name: Literal["account_deleted"] = "account_deleted"
|
||||
username: str
|
||||
|
||||
|
||||
class ClearedCentersNotification(BaseNotification):
|
||||
"""Supply centers cleared."""
|
||||
|
||||
name: Literal["cleared_centers"] = "cleared_centers"
|
||||
game_id: str
|
||||
power_name: str
|
||||
|
|
@ -420,6 +484,7 @@ class ClearedCentersNotification(BaseNotification):
|
|||
|
||||
class ClearedOrdersNotification(BaseNotification):
|
||||
"""Orders cleared."""
|
||||
|
||||
name: Literal["cleared_orders"] = "cleared_orders"
|
||||
game_id: str
|
||||
power_name: str
|
||||
|
|
@ -428,6 +493,7 @@ class ClearedOrdersNotification(BaseNotification):
|
|||
|
||||
class ClearedUnitsNotification(BaseNotification):
|
||||
"""Units cleared."""
|
||||
|
||||
name: Literal["cleared_units"] = "cleared_units"
|
||||
game_id: str
|
||||
power_name: str
|
||||
|
|
@ -514,24 +580,25 @@ WebSocketMessage = Union[RequestMessage, ResponseMessage, NotificationMessage]
|
|||
# Utility Functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def parse_message(data: Dict[str, Any]) -> WebSocketMessage:
|
||||
"""
|
||||
Parse a raw WebSocket message dictionary into the appropriate pydantic model.
|
||||
|
||||
|
||||
Args:
|
||||
data: Raw message dictionary from WebSocket
|
||||
|
||||
|
||||
Returns:
|
||||
Parsed message object
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If message cannot be parsed or is of unknown type
|
||||
"""
|
||||
if not isinstance(data, dict) or "name" not in data:
|
||||
raise ValueError("Invalid message format: missing 'name' field")
|
||||
|
||||
|
||||
message_name = data["name"]
|
||||
|
||||
|
||||
# Map message names to their corresponding classes
|
||||
message_classes = {
|
||||
# Requests
|
||||
|
|
@ -592,22 +659,22 @@ def parse_message(data: Dict[str, Any]) -> WebSocketMessage:
|
|||
"cleared_orders": ClearedOrdersNotification,
|
||||
"cleared_units": ClearedUnitsNotification,
|
||||
}
|
||||
|
||||
|
||||
message_class = message_classes.get(message_name)
|
||||
if message_class is None:
|
||||
raise ValueError(f"Unknown message type: {message_name}")
|
||||
|
||||
|
||||
return message_class(**data)
|
||||
|
||||
|
||||
def serialize_message(message: WebSocketMessage) -> Dict[str, Any]:
|
||||
"""
|
||||
Serialize a pydantic message object to a dictionary for WebSocket transmission.
|
||||
|
||||
|
||||
Args:
|
||||
message: Pydantic message object
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary representation of the message
|
||||
"""
|
||||
return message.model_dump(exclude_none=True)
|
||||
return message.model_dump(exclude_none=True)
|
||||
|
|
|
|||
|
|
@ -50,13 +50,13 @@ class MultiBotLauncher:
|
|||
|
||||
# Default power to model mapping
|
||||
self.default_models = {
|
||||
"AUSTRIA": "gpt-3.5-turbo",
|
||||
"ENGLAND": "gpt-4",
|
||||
"FRANCE": "claude-3-haiku",
|
||||
"GERMANY": "gpt-3.5-turbo",
|
||||
"ITALY": "gemini-pro",
|
||||
"RUSSIA": "gpt-4",
|
||||
"TURKEY": "claude-3-sonnet",
|
||||
"AUSTRIA": "gemini-2.5-flash-lite-preview-06-17",
|
||||
"ENGLAND": "gemini-2.5-flash-lite-preview-06-17",
|
||||
"FRANCE": "gemini-2.5-flash-lite-preview-06-17",
|
||||
"GERMANY": "gemini-2.5-flash-lite-preview-06-17",
|
||||
"ITALY": "gemini-2.5-flash-lite-preview-06-17",
|
||||
"RUSSIA": "gemini-2.5-flash-lite-preview-06-17",
|
||||
"TURKEY": "gemini-2.5-flash-lite-preview-06-17",
|
||||
}
|
||||
|
||||
async def create_game(self, creator_power: str = "FRANCE") -> str:
|
||||
|
|
@ -99,7 +99,12 @@ class MultiBotLauncher:
|
|||
return game_id
|
||||
|
||||
def launch_bot(
|
||||
self, power: str, model: str, game_id: str, log_level: str = "INFO"
|
||||
self,
|
||||
power: str,
|
||||
model: str,
|
||||
game_id: str,
|
||||
log_level: str = "INFO",
|
||||
negotiation_rounds: int = 3,
|
||||
) -> subprocess.Popen:
|
||||
"""
|
||||
Launch a single bot process.
|
||||
|
|
@ -134,6 +139,8 @@ class MultiBotLauncher:
|
|||
game_id,
|
||||
"--log-level",
|
||||
log_level,
|
||||
"--negotiation-rounds",
|
||||
str(negotiation_rounds),
|
||||
]
|
||||
|
||||
logger.info(f"Launching bot for {power} with model {model}")
|
||||
|
|
@ -156,7 +163,8 @@ class MultiBotLauncher:
|
|||
models: Optional[Dict[str, str]] = None,
|
||||
powers: Optional[List[str]] = None,
|
||||
log_level: str = "INFO",
|
||||
stagger_delay: float = 2.0,
|
||||
stagger_delay: float = 0.5,
|
||||
negotiation_rounds: int = 3,
|
||||
):
|
||||
"""
|
||||
Launch bots for all specified powers.
|
||||
|
|
@ -180,13 +188,11 @@ class MultiBotLauncher:
|
|||
model = models.get(power, "gpt-3.5-turbo")
|
||||
|
||||
try:
|
||||
process = self.launch_bot(power, model, game_id, log_level)
|
||||
process = self.launch_bot(power, model, game_id, log_level, negotiation_rounds)
|
||||
self.bot_processes.append(process)
|
||||
self.process_to_power[process] = power
|
||||
|
||||
logger.info(
|
||||
f"Launched bot {i + 1}/{len(powers)}: {power} (PID: {process.pid})"
|
||||
)
|
||||
logger.info(f"Launched bot {i + 1}/{len(powers)}: {power} (PID: {process.pid})")
|
||||
|
||||
# Stagger the launches to avoid overwhelming the server
|
||||
if i < len(powers) - 1: # Don't delay after the last bot
|
||||
|
|
@ -209,11 +215,11 @@ class MultiBotLauncher:
|
|||
try:
|
||||
while self.bot_processes:
|
||||
active_processes = []
|
||||
|
||||
|
||||
# Collect all stdout file descriptors from active processes
|
||||
stdout_fds = []
|
||||
fd_to_process = {}
|
||||
|
||||
|
||||
for process in self.bot_processes:
|
||||
if process.poll() is None: # Still running
|
||||
active_processes.append(process)
|
||||
|
|
@ -224,15 +230,13 @@ class MultiBotLauncher:
|
|||
# Process has ended
|
||||
return_code = process.returncode
|
||||
power = self.process_to_power.get(process, "UNKNOWN")
|
||||
logger.info(
|
||||
f"{power} bot process {process.pid} ended with code {return_code}"
|
||||
)
|
||||
logger.info(f"{power} bot process {process.pid} ended with code {return_code}")
|
||||
|
||||
# Read any remaining output
|
||||
remaining_output = process.stdout.read()
|
||||
if remaining_output:
|
||||
print(f"{power}_{process.pid} final output: {remaining_output}")
|
||||
|
||||
|
||||
# Clean up the power mapping
|
||||
self.process_to_power.pop(process, None)
|
||||
|
||||
|
|
@ -241,20 +245,20 @@ class MultiBotLauncher:
|
|||
if not self.bot_processes:
|
||||
logger.info("All bots have finished")
|
||||
break
|
||||
|
||||
|
||||
# Use select to check which processes have output ready (Unix only)
|
||||
if stdout_fds and hasattr(select, 'select'):
|
||||
if stdout_fds and hasattr(select, "select"):
|
||||
try:
|
||||
ready_fds, _, _ = select.select(stdout_fds, [], [], 0.1) # 100ms timeout
|
||||
|
||||
|
||||
for fd in ready_fds:
|
||||
process = fd_to_process[fd]
|
||||
power = self.process_to_power.get(process, "UNKNOWN")
|
||||
|
||||
|
||||
# Read available lines (but limit to prevent monopolizing)
|
||||
lines_read = 0
|
||||
max_lines_per_process = 10
|
||||
|
||||
|
||||
while lines_read < max_lines_per_process:
|
||||
try:
|
||||
line = process.stdout.readline()
|
||||
|
|
@ -264,7 +268,7 @@ class MultiBotLauncher:
|
|||
lines_read += 1
|
||||
except:
|
||||
break
|
||||
|
||||
|
||||
except (OSError, ValueError):
|
||||
# Fallback if select fails
|
||||
self._fallback_read_output(active_processes)
|
||||
|
|
@ -283,11 +287,11 @@ class MultiBotLauncher:
|
|||
"""Fallback method for reading output when select is not available."""
|
||||
for process in active_processes:
|
||||
power = self.process_to_power.get(process, "UNKNOWN")
|
||||
|
||||
|
||||
# Read limited lines per process to prevent monopolizing
|
||||
lines_read = 0
|
||||
max_lines_per_process = 3 # More conservative for fallback
|
||||
|
||||
|
||||
while lines_read < max_lines_per_process:
|
||||
try:
|
||||
line = process.stdout.readline()
|
||||
|
|
@ -324,6 +328,7 @@ class MultiBotLauncher:
|
|||
models: Optional[Dict[str, str]] = None,
|
||||
log_level: str = "INFO",
|
||||
creator_power: str = "FRANCE",
|
||||
negotiation_rounds: int = 3,
|
||||
):
|
||||
"""
|
||||
Create a game and launch all bots for a complete game.
|
||||
|
|
@ -342,7 +347,12 @@ class MultiBotLauncher:
|
|||
await asyncio.sleep(2)
|
||||
|
||||
# Launch all bots
|
||||
await self.launch_all_bots(game_id, models, log_level=log_level)
|
||||
await self.launch_all_bots(
|
||||
game_id,
|
||||
models,
|
||||
log_level=log_level,
|
||||
negotiation_rounds=negotiation_rounds,
|
||||
)
|
||||
|
||||
# Monitor the bots
|
||||
self.monitor_bots()
|
||||
|
|
@ -358,6 +368,7 @@ class MultiBotLauncher:
|
|||
powers: List[str],
|
||||
models: Optional[Dict[str, str]] = None,
|
||||
log_level: str = "INFO",
|
||||
negotiation_rounds: int = 3,
|
||||
):
|
||||
"""
|
||||
Launch bots to join an existing game.
|
||||
|
|
@ -372,7 +383,13 @@ class MultiBotLauncher:
|
|||
self.game_id = game_id
|
||||
|
||||
# Launch bots for specified powers
|
||||
await self.launch_all_bots(game_id, models, powers, log_level)
|
||||
await self.launch_all_bots(
|
||||
game_id,
|
||||
models,
|
||||
powers,
|
||||
log_level,
|
||||
negotiation_rounds=negotiation_rounds,
|
||||
)
|
||||
|
||||
# Monitor the bots
|
||||
self.monitor_bots()
|
||||
|
|
@ -391,18 +408,16 @@ def parse_arguments():
|
|||
parser.add_argument("--port", type=int, default=8432, help="Server port")
|
||||
parser.add_argument("--username-base", default="bot", help="Base username for bots")
|
||||
parser.add_argument("--password", default="password", help="Password for all bots")
|
||||
parser.add_argument(
|
||||
"--game-id", help="Game ID to join (creates new if not specified)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--powers", nargs="+", help="Powers to launch bots for (default: all)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--models", help="Comma-separated list of models in power order"
|
||||
)
|
||||
parser.add_argument("--game-id", help="Game ID to join (creates new if not specified)")
|
||||
parser.add_argument("--powers", nargs="+", help="Powers to launch bots for (default: all)")
|
||||
parser.add_argument("--models", help="Comma-separated list of models in power order")
|
||||
parser.add_argument("--log-level", default="INFO", help="Logging level")
|
||||
parser.add_argument("--creator-power", default="FRANCE", help="Power that creates the game")
|
||||
parser.add_argument(
|
||||
"--creator-power", default="FRANCE", help="Power that creates the game"
|
||||
"--negotiation-rounds",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Number of negotiation rounds per movement phase (default: 3)",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
|
@ -425,9 +440,7 @@ async def main():
|
|||
model_list = [m.strip() for m in args.models.split(",")]
|
||||
powers = args.powers or list(launcher.default_models.keys())
|
||||
if len(model_list) != len(powers):
|
||||
logger.error(
|
||||
f"Number of models ({len(model_list)}) must match number of powers ({len(powers)})"
|
||||
)
|
||||
logger.error(f"Number of models ({len(model_list)}) must match number of powers ({len(powers)})")
|
||||
return
|
||||
models = dict(zip(powers, model_list))
|
||||
|
||||
|
|
@ -440,6 +453,7 @@ async def main():
|
|||
powers=powers,
|
||||
models=models,
|
||||
log_level=args.log_level,
|
||||
negotiation_rounds=args.negotiation_rounds,
|
||||
)
|
||||
else:
|
||||
# Create new game and launch all bots
|
||||
|
|
@ -447,6 +461,7 @@ async def main():
|
|||
models=models,
|
||||
log_level=args.log_level,
|
||||
creator_power=args.creator_power,
|
||||
negotiation_rounds=args.negotiation_rounds,
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ include = [".", ".."]
|
|||
|
||||
[tool.ruff]
|
||||
src = ["..", "."]
|
||||
line-length = 150
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["bot_client/tests"]
|
||||
|
|
@ -44,8 +45,4 @@ markers = [
|
|||
]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"pytest>=8.4.1",
|
||||
"pytest-asyncio>=0.21.0",
|
||||
"websockets>=15.0.1",
|
||||
]
|
||||
dev = ["pytest>=8.4.1", "pytest-asyncio>=0.21.0", "websockets>=15.0.1"]
|
||||
|
|
|
|||
|
|
@ -13,25 +13,17 @@ sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
|||
import argparse
|
||||
import asyncio
|
||||
import signal
|
||||
from typing import Optional, Dict
|
||||
from typing import Optional, Dict, List
|
||||
import dotenv
|
||||
from loguru import logger
|
||||
|
||||
|
||||
from websocket_diplomacy_client import (
|
||||
WebSocketDiplomacyClient,
|
||||
connect_to_diplomacy_server,
|
||||
)
|
||||
|
||||
from diplomacy.utils.exceptions import DiplomacyException, GameIdException
|
||||
|
||||
# Suppress warnings
|
||||
# os.environ["GRPC_PYTHON_LOG_LEVEL"] = "40"
|
||||
# os.environ["GRPC_VERBOSITY"] = "ERROR"
|
||||
# os.environ["ABSL_MIN_LOG_LEVEL"] = "2"
|
||||
# os.environ["GRPC_POLL_STRATEGY"] = "poll"
|
||||
from websocket_diplomacy_client import WebSocketDiplomacyClient, connect_to_diplomacy_server
|
||||
|
||||
|
||||
from diplomacy.engine.power import Power
|
||||
from diplomacy.utils.exceptions import GameIdException
|
||||
from diplomacy.communication.notifications import GameStatusUpdate
|
||||
from diplomacy.engine.message import Message
|
||||
|
||||
from ai_diplomacy.clients import load_model_client
|
||||
|
|
@ -40,15 +32,16 @@ from ai_diplomacy.game_history import GameHistory
|
|||
from ai_diplomacy.agent import DiplomacyAgent
|
||||
from ai_diplomacy.initialization import initialize_agent_state_ext
|
||||
from config import Configuration
|
||||
from websocket_negotiations import (
|
||||
conduct_strategic_negotiation_round,
|
||||
should_participate_in_negotiations,
|
||||
get_negotiation_delay,
|
||||
)
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
config = Configuration()
|
||||
|
||||
if config.DEBUG:
|
||||
import tracemalloc
|
||||
|
||||
tracemalloc.start()
|
||||
# TODO: This, but better
|
||||
config: Configuration = None
|
||||
|
||||
|
||||
class SingleBotPlayer:
|
||||
|
|
@ -70,6 +63,7 @@ class SingleBotPlayer:
|
|||
hostname: str = "localhost",
|
||||
port: int = 8432,
|
||||
game_id: Optional[str] = None,
|
||||
negotiation_rounds: int = 3,
|
||||
):
|
||||
assert username is not None
|
||||
assert password is not None
|
||||
|
|
@ -87,16 +81,24 @@ class SingleBotPlayer:
|
|||
# Bot state
|
||||
self.client: WebSocketDiplomacyClient
|
||||
self.agent: DiplomacyAgent
|
||||
self.game_history = GameHistory()
|
||||
self.game_history: GameHistory = GameHistory()
|
||||
self.running = True
|
||||
self.current_phase = None
|
||||
self.waiting_for_orders = False
|
||||
self.orders_submitted = False
|
||||
|
||||
# Negotiation settings
|
||||
self.negotiation_rounds = negotiation_rounds
|
||||
self.current_negotiation_round = 0
|
||||
self.negotiation_complete = False
|
||||
|
||||
# Track error stats
|
||||
self.error_stats: Dict[str, Dict[str, int]] = {
|
||||
self.model_name: {"conversation_errors": 0, "order_decoding_errors": 0}
|
||||
}
|
||||
self.error_stats: Dict[str, Dict[str, int]] = {self.model_name: {"conversation_errors": 0, "order_decoding_errors": 0}}
|
||||
|
||||
# Track messaging patterns for strategic communication
|
||||
self.message_counts: Dict[str, int] = {} # Messages received from each power
|
||||
self.response_counts: Dict[str, int] = {} # Responses sent to each power
|
||||
self.priority_contacts: List[str] = [] # Powers to prioritize for communication
|
||||
|
||||
# Setup signal handlers for graceful shutdown
|
||||
signal.signal(signal.SIGINT, self._signal_handler)
|
||||
|
|
@ -122,9 +124,7 @@ class SingleBotPlayer:
|
|||
# Join or create game
|
||||
if self.game_id:
|
||||
logger.info(f"Joining existing game {self.game_id} as {self.power_name}")
|
||||
game = await self.client.join_game(
|
||||
game_id=self.game_id, power_name=self.power_name
|
||||
)
|
||||
await self.client.join_game(game_id=self.game_id, power_name=self.power_name)
|
||||
else:
|
||||
logger.info(f"Creating new game as {self.power_name}")
|
||||
await self.client.create_game(
|
||||
|
|
@ -142,9 +142,7 @@ class SingleBotPlayer:
|
|||
self.agent = DiplomacyAgent(power_name=self.power_name, client=model_client)
|
||||
|
||||
# Initialize agent state
|
||||
await initialize_agent_state_ext(
|
||||
self.agent, self.client.game, self.game_history, config.log_file_path
|
||||
)
|
||||
await initialize_agent_state_ext(self.agent, self.client.game, self.game_history, config.log_file_path)
|
||||
|
||||
# Setup game event callbacks
|
||||
await self._setup_event_callbacks()
|
||||
|
|
@ -198,6 +196,16 @@ class SingleBotPlayer:
|
|||
self.current_phase = new_phase
|
||||
self.game_history.add_phase(new_phase)
|
||||
self.orders_submitted = False
|
||||
self.current_negotiation_round = 0
|
||||
self.negotiation_complete = False
|
||||
|
||||
# Log message statistics at phase transitions
|
||||
if hasattr(self, "message_counts") and self.message_counts:
|
||||
self.log_message_statistics()
|
||||
|
||||
# Check if we should start negotiations for movement phases (not adjustment phases)
|
||||
if new_phase.endswith("M"):
|
||||
await self._handle_negotiation_phase()
|
||||
|
||||
# Check if we need to submit orders for this new phase
|
||||
await self._check_if_orders_needed()
|
||||
|
|
@ -223,9 +231,7 @@ class SingleBotPlayer:
|
|||
def _on_message_received(self, game, notification):
|
||||
"""Handle incoming diplomatic messages."""
|
||||
message = notification.message
|
||||
logger.info(
|
||||
f"Message received from {message.sender} to {message.recipient}: {message.message}"
|
||||
)
|
||||
logger.info(f"Message received from {message.sender} to {message.recipient}: {message.message}")
|
||||
|
||||
# Add message to game history
|
||||
self.game_history.add_message(
|
||||
|
|
@ -235,12 +241,17 @@ class SingleBotPlayer:
|
|||
message_content=message.message,
|
||||
)
|
||||
|
||||
# Track message patterns
|
||||
if message.recipient == self.power_name and message.sender != self.power_name:
|
||||
self.message_counts[message.sender] = self.message_counts.get(message.sender, 0) + 1
|
||||
self._update_priority_contacts()
|
||||
|
||||
# If it's a private message to us, consider responding
|
||||
if message.recipient == self.power_name and message.sender != self.power_name:
|
||||
# Schedule the async processing in the event loop
|
||||
asyncio.create_task(self._consider_message_response(message))
|
||||
|
||||
def _on_status_update(self, game, notification):
|
||||
def _on_status_update(self, game, notification: GameStatusUpdate):
|
||||
"""Handle game status changes."""
|
||||
logger.info(f"Game status updated: {notification.status}")
|
||||
|
||||
|
|
@ -261,20 +272,50 @@ class SingleBotPlayer:
|
|||
# Check if it's a phase where we can submit orders
|
||||
current_short_phase = self.client.game.current_short_phase
|
||||
|
||||
# We submit orders in Movement and Retreat phases
|
||||
if current_short_phase.endswith("M") or current_short_phase.endswith("R"):
|
||||
# Check if we have units that can receive orders
|
||||
orderable_locations = self.client.game.get_orderable_locations(
|
||||
self.power_name
|
||||
)
|
||||
if orderable_locations:
|
||||
logger.info(f"Orders needed for phase {current_short_phase}")
|
||||
self.waiting_for_orders = True
|
||||
await self._submit_orders()
|
||||
else:
|
||||
logger.info(
|
||||
f"No orderable locations for {self.power_name} in {current_short_phase}"
|
||||
)
|
||||
logger.debug(f"Checking if orders needed for phase: {current_short_phase}")
|
||||
|
||||
# Movement and Retreat phases
|
||||
orderable_locations = self.client.game.get_orderable_locations(self.power_name)
|
||||
if orderable_locations:
|
||||
logger.info(f"Orders needed for {current_short_phase} phase - orderable locations: {orderable_locations}")
|
||||
self.waiting_for_orders = True
|
||||
await self._submit_orders()
|
||||
else:
|
||||
logger.info(f"No orderable locations for {self.power_name} in {current_short_phase}")
|
||||
|
||||
async def _submit_adjustment_orders(self, action_type, count):
|
||||
"""Submit build or disband orders for adjustment phase."""
|
||||
# FIXME: This whole function is horse shit.
|
||||
return
|
||||
if action_type == "build":
|
||||
# Get buildable locations for this power
|
||||
possible_orders = gather_possible_order(self.client.game, self.power_name)
|
||||
if not possible_orders:
|
||||
logger.warning(f"No possible build orders for {self.power_name}")
|
||||
await self.client.set_orders(self.power_name, [])
|
||||
self.orders_submitted = True
|
||||
return
|
||||
|
||||
# Filter for build orders (usually start with unit type + location)
|
||||
build_orders = [order for order in possible_orders if " - " not in order and any(order.startswith(unit) for unit in ["A ", "F "])]
|
||||
|
||||
logger.info(f"Available build orders for {self.power_name}: {build_orders}")
|
||||
|
||||
# Select up to 'count' build orders
|
||||
selected_orders = build_orders[:count]
|
||||
logger.info(f"Submitting build orders for {self.power_name}: {selected_orders}")
|
||||
|
||||
await self.client.set_orders(self.power_name, selected_orders)
|
||||
|
||||
elif action_type == "disband":
|
||||
# Get current units for disbanding
|
||||
|
||||
logger.info(f"Submitting disband orders for {self.power_name}: {disband_orders}")
|
||||
await self.client.set_orders(self.power_name, disband_orders)
|
||||
|
||||
self.orders_submitted = True
|
||||
self.waiting_for_orders = False
|
||||
logger.info(f"Adjustment orders submitted successfully for {self.power_name}")
|
||||
|
||||
async def _submit_orders(self):
|
||||
"""Generate and submit orders for the current phase."""
|
||||
|
|
@ -282,108 +323,280 @@ class SingleBotPlayer:
|
|||
logger.debug("Orders already submitted for this phase")
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info("Generating orders...")
|
||||
current_phase = self.client.game.get_current_phase()
|
||||
logger.info(f"Generating orders for {self.power_name} in phase {current_phase}...")
|
||||
|
||||
# Get current board state
|
||||
board_state = self.client.game.get_state()
|
||||
# Get current board state
|
||||
board_state = self.client.game.get_state()
|
||||
|
||||
# Get possible orders
|
||||
possible_orders = gather_possible_orders(self.client.game, self.power_name)
|
||||
# Get possible orders
|
||||
possible_orders = gather_possible_orders(self.client.game, self.power_name)
|
||||
|
||||
if not possible_orders:
|
||||
logger.info("No possible orders, submitting empty order set")
|
||||
await self.client.set_orders(self.power_name, [])
|
||||
self.orders_submitted = True
|
||||
return
|
||||
|
||||
# Generate orders using AI
|
||||
orders = await get_valid_orders(
|
||||
game=self.client.game,
|
||||
client=self.agent.client,
|
||||
board_state=board_state,
|
||||
power_name=self.power_name,
|
||||
possible_orders=possible_orders,
|
||||
game_history=self.game_history,
|
||||
model_error_stats=self.error_stats,
|
||||
agent_goals=self.agent.goals,
|
||||
agent_relationships=self.agent.relationships,
|
||||
agent_private_diary_str=self.agent.format_private_diary_for_prompt(),
|
||||
phase=self.client.game.get_current_phase(),
|
||||
)
|
||||
|
||||
# Submit orders
|
||||
if orders:
|
||||
logger.info(f"Submitting orders: {orders}")
|
||||
await self.client.set_orders(self.power_name, orders)
|
||||
|
||||
# Generate order diary entry
|
||||
await self.agent.generate_order_diary_entry(
|
||||
self.client.game,
|
||||
orders,
|
||||
config.log_file_path,
|
||||
)
|
||||
else:
|
||||
logger.info("No valid orders generated, submitting empty order set")
|
||||
await self.client.set_orders(self.power_name, [])
|
||||
logger.debug(f"Possible orders for {self.power_name}: {possible_orders}")
|
||||
|
||||
if not possible_orders:
|
||||
logger.info(f"No possible orders for {self.power_name}, submitting empty order set")
|
||||
await self.client.set_orders(self.power_name, [])
|
||||
self.orders_submitted = True
|
||||
self.waiting_for_orders = False
|
||||
logger.info("Orders submitted successfully")
|
||||
# Call the no wait so we don't sit around for the turns to end.
|
||||
self.client.game.no_wait()
|
||||
return
|
||||
|
||||
except DiplomacyException as e:
|
||||
logger.error(f"Error submitting orders: {e}", exc_info=True)
|
||||
# FIXME: I don't think we want to do this. Likely want to retry again multiple times.
|
||||
#
|
||||
# Submit empty orders as fallback
|
||||
try:
|
||||
await self.client.set_orders(self.power_name, [])
|
||||
self.orders_submitted = True
|
||||
except Exception as fallback_error:
|
||||
logger.error(f"Failed to submit fallback orders: {fallback_error}")
|
||||
# Generate orders using AI
|
||||
orders = await get_valid_orders(
|
||||
game=self.client.game,
|
||||
client=self.agent.client,
|
||||
board_state=board_state,
|
||||
power_name=self.power_name,
|
||||
possible_orders=possible_orders,
|
||||
game_history=self.game_history,
|
||||
model_error_stats=self.error_stats,
|
||||
agent_goals=self.agent.goals,
|
||||
agent_relationships=self.agent.relationships,
|
||||
agent_private_diary_str=self.agent.format_private_diary_for_prompt(),
|
||||
phase=self.client.game.get_current_phase(),
|
||||
)
|
||||
|
||||
# Submit orders
|
||||
if orders:
|
||||
logger.info(f"Submitting orders: {orders}")
|
||||
await self.client.set_orders(self.power_name, orders)
|
||||
|
||||
# Generate order diary entry
|
||||
await self.agent.generate_order_diary_entry(
|
||||
self.client.game,
|
||||
orders,
|
||||
config.log_file_path,
|
||||
)
|
||||
else:
|
||||
logger.info("No valid orders generated, submitting empty order set")
|
||||
await self.client.set_orders(self.power_name, [])
|
||||
|
||||
self.orders_submitted = True
|
||||
self.waiting_for_orders = False
|
||||
logger.info("Orders submitted successfully")
|
||||
# Call the no wait so we don't sit around for the turns to end.
|
||||
# TODO: We probably don't want to call this here.
|
||||
# We want to call it when negotiations end,
|
||||
self.client.game.no_wait()
|
||||
|
||||
async def _analyze_phase_results(self):
|
||||
"""Analyze the results of the previous phase."""
|
||||
try:
|
||||
logger.info("Analyzing phase results...")
|
||||
logger.info("Analyzing phase results...")
|
||||
|
||||
# Get current board state after processing
|
||||
board_state = self.client.game.get_state()
|
||||
# Get current board state after processing
|
||||
board_state = self.client.game.get_state()
|
||||
|
||||
# Generate a simple phase summary
|
||||
phase_summary = f"Phase {self.current_phase} completed."
|
||||
# Generate a simple phase summary
|
||||
phase_summary = f"Phase {self.current_phase} completed."
|
||||
|
||||
# Update agent state based on results
|
||||
await self.agent.analyze_phase_and_update_state(
|
||||
game=self.client.game,
|
||||
board_state=board_state,
|
||||
phase_summary=phase_summary,
|
||||
# Update agent state based on results
|
||||
await self.agent.analyze_phase_and_update_state(
|
||||
game=self.client.game,
|
||||
board_state=board_state,
|
||||
phase_summary=phase_summary,
|
||||
game_history=self.game_history,
|
||||
log_file_path=config.log_file_path,
|
||||
)
|
||||
|
||||
logger.info("Phase analysis complete")
|
||||
|
||||
async def _handle_negotiation_phase(self):
|
||||
"""Handle the negotiation phase for movement turns."""
|
||||
# Check if we should participate in negotiations
|
||||
if not await should_participate_in_negotiations(self.client, self.agent):
|
||||
logger.info(f"{self.power_name} will not participate in negotiations this phase")
|
||||
self.negotiation_complete = True
|
||||
return
|
||||
|
||||
logger.info(f"Starting negotiation phase for {self.power_name}")
|
||||
# TODO: This doesn't need a specific number of negotiation rounds, though it should have a top number of messages this turn so they don't blabber on forever.
|
||||
|
||||
# Conduct negotiations for the specified number of rounds
|
||||
for round_num in range(1, self.negotiation_rounds + 1):
|
||||
self.current_negotiation_round = round_num
|
||||
|
||||
logger.info(f"Negotiation round {round_num}/{self.negotiation_rounds} for {self.power_name}")
|
||||
|
||||
# Use strategic negotiation that analyzes recent messages
|
||||
success = await conduct_strategic_negotiation_round(
|
||||
client=self.client,
|
||||
agent=self.agent,
|
||||
game_history=self.game_history,
|
||||
model_error_stats=self.error_stats,
|
||||
log_file_path=config.log_file_path,
|
||||
round_number=round_num,
|
||||
max_rounds=self.negotiation_rounds,
|
||||
)
|
||||
|
||||
logger.info("Phase analysis complete")
|
||||
if not success:
|
||||
logger.info(f"No messages sent in round {round_num} for {self.power_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error analyzing phase results: {e}", exc_info=True)
|
||||
# Wait between rounds to allow other bots to respond
|
||||
if round_num < self.negotiation_rounds:
|
||||
delay = get_negotiation_delay(round_num, self.negotiation_rounds)
|
||||
logger.debug(f"Waiting {delay}s before next negotiation round")
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
self.negotiation_complete = True
|
||||
logger.info(f"Negotiation phase complete for {self.power_name}")
|
||||
|
||||
async def _consider_message_response(self, message: Message):
|
||||
"""Consider whether to respond to a diplomatic message."""
|
||||
try:
|
||||
# Simple logic: if someone greets us, greet back
|
||||
if any(
|
||||
word in message.message.lower() for word in ["hello", "hi", "greetings"]
|
||||
):
|
||||
response = f"Hello {message.sender}! Good to hear from you."
|
||||
await self.client.game.send_game_message(
|
||||
sender=self.power_name, recipient=message.sender, message=response
|
||||
# Only respond to messages directed at us specifically
|
||||
if message.recipient != self.power_name:
|
||||
return
|
||||
|
||||
# Don't respond to our own messages
|
||||
if message.sender == self.power_name:
|
||||
return
|
||||
|
||||
logger.info(f"Considering response to message from {message.sender}: {message.message[:50]}...")
|
||||
|
||||
# Enhanced heuristic: respond to direct questions, proposals, and strategic keywords
|
||||
message_lower = message.message.lower()
|
||||
strategic_keywords = [
|
||||
"alliance",
|
||||
"deal",
|
||||
"propose",
|
||||
"agreement",
|
||||
"support",
|
||||
"attack",
|
||||
"coordinate",
|
||||
"move",
|
||||
"order",
|
||||
"help",
|
||||
"work together",
|
||||
"partner",
|
||||
"enemy",
|
||||
"threat",
|
||||
"negotiate",
|
||||
"discuss",
|
||||
"plan",
|
||||
"strategy",
|
||||
"bounce",
|
||||
"convoy",
|
||||
"retreat",
|
||||
]
|
||||
|
||||
should_respond = any(
|
||||
[
|
||||
"?" in message.message, # Questions
|
||||
any(word in message_lower for word in ["hello", "hi", "greetings"]), # Greetings
|
||||
any(keyword in message_lower for keyword in strategic_keywords), # Strategic content
|
||||
len(message.message.split()) > 15, # Longer messages suggest they want engagement
|
||||
message.sender in self.priority_contacts, # Priority contacts
|
||||
]
|
||||
)
|
||||
|
||||
if should_respond:
|
||||
# Generate a contextual response using AI
|
||||
# Get current game state for context
|
||||
board_state = self.client.get_state()
|
||||
possible_orders = gather_possible_orders(self.client.game, self.power_name)
|
||||
|
||||
# Create a simple conversation context
|
||||
active_powers = [p_name for p_name, p_obj in self.client.powers.items() if not p_obj.is_eliminated()]
|
||||
|
||||
# Generate response using the agent's conversation capabilities
|
||||
responses = await self.agent.client.get_conversation_reply(
|
||||
game=self.client.game,
|
||||
board_state=board_state,
|
||||
power_name=self.power_name,
|
||||
possible_orders=possible_orders,
|
||||
game_history=self.game_history,
|
||||
game_phase=self.client.get_current_short_phase(),
|
||||
log_file_path=config.log_file_path,
|
||||
active_powers=active_powers,
|
||||
agent_goals=self.agent.goals,
|
||||
agent_relationships=self.agent.relationships,
|
||||
agent_private_diary_str=self.agent.format_private_diary_for_prompt(),
|
||||
)
|
||||
logger.info(f"Sent response to {message.sender}: {response}")
|
||||
|
||||
# Send the first response if any were generated
|
||||
if responses and len(responses) > 0:
|
||||
response_content = responses[0].get("content", "").strip()
|
||||
if response_content:
|
||||
await self.client.send_message(
|
||||
sender=self.power_name,
|
||||
recipient=message.sender,
|
||||
message=response_content,
|
||||
phase=self.client.get_current_short_phase(),
|
||||
)
|
||||
|
||||
# Add to game history
|
||||
self.game_history.add_message(
|
||||
phase_name=self.client.get_current_short_phase(),
|
||||
sender=self.power_name,
|
||||
recipient=message.sender,
|
||||
message_content=response_content,
|
||||
)
|
||||
|
||||
# Track response patterns
|
||||
self.response_counts[message.sender] = self.response_counts.get(message.sender, 0) + 1
|
||||
|
||||
# Add to agent's journal
|
||||
self.agent.add_journal_entry(
|
||||
f"Responded to {message.sender} in {self.client.get_current_short_phase()}: {response_content[:100]}..."
|
||||
)
|
||||
|
||||
logger.info(f"Sent AI response to {message.sender}: {response_content[:50]}...")
|
||||
else:
|
||||
logger.debug(f"AI generated empty response to {message.sender}")
|
||||
else:
|
||||
logger.debug(f"AI generated no responses to {message.sender}")
|
||||
else:
|
||||
logger.debug(f"Decided not to respond to message from {message.sender}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error responding to message: {e}")
|
||||
logger.error(f"Error responding to message: {e}", exc_info=True)
|
||||
|
||||
def _update_priority_contacts(self) -> None:
|
||||
"""Update the list of priority contacts based on messaging patterns."""
|
||||
# Sort powers by message count (descending) and take top 3-4
|
||||
sorted_contacts = sorted(self.message_counts.items(), key=lambda x: x[1], reverse=True)
|
||||
|
||||
# Keep top 4 most active contacts as priority
|
||||
self.priority_contacts = [contact[0] for contact in sorted_contacts[:4]]
|
||||
|
||||
logger.debug(f"Updated priority contacts for {self.power_name}: {self.priority_contacts}")
|
||||
|
||||
def get_message_statistics(self) -> Dict[str, any]:
|
||||
"""Get comprehensive statistics about messaging patterns."""
|
||||
active_powers = [p_name for p_name, p_obj in self.client.powers.items() if not p_obj.is_eliminated() and p_name != self.power_name]
|
||||
|
||||
stats = {
|
||||
"power_name": self.power_name,
|
||||
"total_messages_received": sum(self.message_counts.values()),
|
||||
"total_responses_sent": sum(self.response_counts.values()),
|
||||
"message_counts_by_power": dict(self.message_counts),
|
||||
"response_counts_by_power": dict(self.response_counts),
|
||||
"priority_contacts": list(self.priority_contacts),
|
||||
"response_rate_by_power": {},
|
||||
"active_powers": active_powers,
|
||||
"current_phase": self.current_phase,
|
||||
}
|
||||
|
||||
# Calculate response rates
|
||||
for power, received in self.message_counts.items():
|
||||
sent = self.response_counts.get(power, 0)
|
||||
stats["response_rate_by_power"][power] = sent / received if received > 0 else 0.0
|
||||
|
||||
return stats
|
||||
|
||||
def log_message_statistics(self) -> None:
|
||||
"""Log current message statistics for analysis."""
|
||||
stats = self.get_message_statistics()
|
||||
|
||||
logger.info(f"Message Statistics for {self.power_name}:")
|
||||
logger.info(f" Total messages received: {stats['total_messages_received']}")
|
||||
logger.info(f" Total responses sent: {stats['total_responses_sent']}")
|
||||
logger.info(f" Priority contacts: {stats['priority_contacts']}")
|
||||
|
||||
for power in stats["active_powers"]:
|
||||
received = stats["message_counts_by_power"].get(power, 0)
|
||||
sent = stats["response_counts_by_power"].get(power, 0)
|
||||
rate = stats["response_rate_by_power"].get(power, 0.0)
|
||||
logger.info(f" {power}: {received} received, {sent} sent, {rate:.1%} response rate")
|
||||
|
||||
async def run(self):
|
||||
"""Main bot loop."""
|
||||
|
|
@ -408,9 +621,7 @@ class SingleBotPlayer:
|
|||
else:
|
||||
logger.info("Bot shutting down")
|
||||
except GameIdException:
|
||||
logger.error(
|
||||
f"Game with id {self.game_id} does not exist on the server. Exiting..."
|
||||
)
|
||||
logger.error(f"Game with id {self.game_id} does not exist on the server. Exiting...")
|
||||
finally:
|
||||
await self.cleanup()
|
||||
|
||||
|
|
@ -437,10 +648,14 @@ def parse_arguments():
|
|||
parser.add_argument("--password", default="password", help="Bot password")
|
||||
parser.add_argument("--power", default="FRANCE", help="Power to control")
|
||||
parser.add_argument("--model", default="gpt-3.5-turbo", help="AI model to use")
|
||||
parser.add_argument(
|
||||
"--game-id", help="Game ID to join (creates new if not specified)"
|
||||
)
|
||||
parser.add_argument("--game-id", help="Game ID to join (creates new if not specified)")
|
||||
parser.add_argument("--log-level", default="INFO", help="Logging level")
|
||||
parser.add_argument(
|
||||
"--negotiation-rounds",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Number of negotiation rounds per movement phase (default: 3)",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
|
@ -459,8 +674,9 @@ async def main():
|
|||
power_name=args.power,
|
||||
model_name=args.model,
|
||||
game_id=args.game_id,
|
||||
negotiation_rounds=args.negotiation_rounds,
|
||||
)
|
||||
|
||||
config = Configuration(args.power)
|
||||
await bot.run()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -27,9 +27,7 @@ class WebSocketDiplomacyClient:
|
|||
but communicates with a remote server via WebSocket connections.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, hostname: str = "localhost", port: int = 8432, use_ssl: bool = False
|
||||
):
|
||||
def __init__(self, hostname: str = "localhost", port: int = 8432, use_ssl: bool = False):
|
||||
"""
|
||||
Initialize the WebSocket client.
|
||||
|
||||
|
|
@ -199,9 +197,7 @@ class WebSocketDiplomacyClient:
|
|||
|
||||
return await self.channel.get_available_maps()
|
||||
|
||||
async def set_orders(
|
||||
self, power_name: str, orders: List[str], wait: Optional[bool] = None
|
||||
) -> None:
|
||||
async def set_orders(self, power_name: str, orders: List[str], wait: Optional[bool] = None) -> None:
|
||||
"""
|
||||
Set orders for a power.
|
||||
|
||||
|
|
@ -246,9 +242,7 @@ class WebSocketDiplomacyClient:
|
|||
else:
|
||||
await self.game.no_wait()
|
||||
|
||||
async def send_message(
|
||||
self, sender: str, recipient: str, message: str, phase: Optional[str] = None
|
||||
) -> None:
|
||||
async def send_message(self, sender: str, recipient: str, message: str, phase: Optional[str] = None) -> None:
|
||||
"""
|
||||
Send a diplomatic message.
|
||||
|
||||
|
|
@ -269,6 +263,57 @@ class WebSocketDiplomacyClient:
|
|||
logger.debug(f"Sending message from {sender} to {recipient}: {message}")
|
||||
await self.game.send_game_message(message=msg)
|
||||
|
||||
async def send_broadcast_message(self, sender: str, message: str, phase: Optional[str] = None) -> None:
|
||||
"""
|
||||
Send a broadcast message to all active powers.
|
||||
|
||||
Args:
|
||||
sender: Sending power name
|
||||
message: Message content
|
||||
phase: Game phase (uses current phase if None)
|
||||
"""
|
||||
if not self.game:
|
||||
raise DiplomacyException("Must join a game first")
|
||||
|
||||
if phase is None:
|
||||
phase = self.game.current_short_phase
|
||||
|
||||
# Send to all active powers
|
||||
active_powers = [power_name for power_name, power in self.powers.items() if not power.is_eliminated() and power_name != sender]
|
||||
|
||||
for recipient in active_powers:
|
||||
await self.send_message(sender, recipient, message, phase)
|
||||
|
||||
logger.debug(f"Broadcast message from {sender} to {len(active_powers)} powers: {message}")
|
||||
|
||||
async def get_recent_messages(self, phase: Optional[str] = None, limit: int = 50) -> List[Message]:
|
||||
"""
|
||||
Get recent messages from the game.
|
||||
|
||||
Args:
|
||||
phase: Filter by specific phase (None for current phase)
|
||||
limit: Maximum number of messages to return
|
||||
|
||||
Returns:
|
||||
List of recent messages
|
||||
"""
|
||||
if not self.game:
|
||||
raise DiplomacyException("Must join a game first")
|
||||
|
||||
if phase is None:
|
||||
phase = self.game.current_short_phase
|
||||
|
||||
# Get messages from the game and filter by phase
|
||||
all_messages = list(self.messages.values())
|
||||
|
||||
# Filter by phase if specified
|
||||
filtered_messages = [msg for msg in all_messages if msg.phase == phase]
|
||||
|
||||
# Sort by time (most recent first) and limit
|
||||
filtered_messages.sort(key=lambda m: m.time_sent or 0, reverse=True)
|
||||
|
||||
return filtered_messages[:limit]
|
||||
|
||||
async def process_game(self) -> None:
|
||||
"""
|
||||
Force the game to process immediately (admin/moderator only).
|
||||
|
|
@ -289,9 +334,7 @@ class WebSocketDiplomacyClient:
|
|||
logger.debug("Synchronizing game state")
|
||||
await self.game.synchronize()
|
||||
|
||||
async def get_phase_history(
|
||||
self, from_phase: Optional[str] = None, to_phase: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
async def get_phase_history(self, from_phase: Optional[str] = None, to_phase: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get phase history for the game.
|
||||
|
||||
|
|
@ -305,9 +348,7 @@ class WebSocketDiplomacyClient:
|
|||
if not self.game:
|
||||
raise DiplomacyException("Must join a game first")
|
||||
|
||||
return await self.game.get_phase_history(
|
||||
from_phase=from_phase, to_phase=to_phase
|
||||
)
|
||||
return await self.game.get_phase_history(from_phase=from_phase, to_phase=to_phase)
|
||||
|
||||
async def vote(self, power_name: str, vote: str) -> None:
|
||||
"""
|
||||
|
|
|
|||
326
bot_client/websocket_negotiations.py
Normal file
326
bot_client/websocket_negotiations.py
Normal file
|
|
@ -0,0 +1,326 @@
|
|||
"""
|
||||
WebSocket-specific negotiation logic for single bot players.
|
||||
|
||||
This module provides negotiation capabilities for individual bots connected
|
||||
via WebSocket, adapted from the multi-agent negotiation system in the main
|
||||
ai_diplomacy package.
|
||||
"""
|
||||
|
||||
from typing import Dict, Optional, List
|
||||
from loguru import logger
|
||||
|
||||
from diplomacy.engine.message import GLOBAL
|
||||
|
||||
from ai_diplomacy.utils import gather_possible_orders
|
||||
|
||||
from ai_diplomacy.agent import DiplomacyAgent
|
||||
from ai_diplomacy.game_history import GameHistory
|
||||
from websocket_diplomacy_client import WebSocketDiplomacyClient
|
||||
|
||||
|
||||
async def conduct_single_bot_negotiation(
|
||||
client: "WebSocketDiplomacyClient",
|
||||
agent: "DiplomacyAgent",
|
||||
game_history: "GameHistory",
|
||||
model_error_stats: Dict[str, Dict[str, int]],
|
||||
log_file_path: str,
|
||||
max_rounds: int = 3,
|
||||
round_number: int = 1,
|
||||
prioritize_targets: Optional[List[str]] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Conduct negotiation for a single bot during one negotiation round.
|
||||
|
||||
This function handles message generation and sending for one bot during
|
||||
a negotiation phase. Unlike the multi-agent version, this focuses on
|
||||
a single power and sends messages via WebSocket.
|
||||
|
||||
Args:
|
||||
client: WebSocket diplomacy client
|
||||
agent: The bot's AI agent
|
||||
game_history: Game history tracker
|
||||
model_error_stats: Error statistics tracking
|
||||
log_file_path: Path for logging
|
||||
max_rounds: Maximum number of negotiation rounds
|
||||
round_number: Current round number (1-indexed)
|
||||
prioritize_targets: Optional list of powers to prioritize for messaging
|
||||
|
||||
Returns:
|
||||
True if messages were sent successfully, False otherwise
|
||||
"""
|
||||
power_name = agent.power_name
|
||||
logger.info(f"Starting negotiation round {round_number}/{max_rounds} for {power_name}")
|
||||
|
||||
# Check if this power is eliminated
|
||||
if client.get_power(power_name).is_eliminated():
|
||||
logger.info(f"{power_name} is eliminated, skipping negotiation")
|
||||
return False
|
||||
|
||||
# Check if this power has any orderable locations
|
||||
possible_orders = gather_possible_orders(client.game, power_name)
|
||||
if not possible_orders:
|
||||
logger.info(f"No orderable locations for {power_name}, skipping negotiation")
|
||||
return False
|
||||
|
||||
# Get active powers for context
|
||||
active_powers = [p_name for p_name, p_obj in client.powers.items() if not p_obj.is_eliminated()]
|
||||
|
||||
# Prioritize message targets if specified
|
||||
message_targets = prioritize_targets if prioritize_targets else active_powers
|
||||
message_targets = [p for p in message_targets if p in active_powers and p != power_name]
|
||||
|
||||
# Generate conversation messages using the AI agent
|
||||
board_state = client.get_state()
|
||||
|
||||
messages = await agent.client.get_conversation_reply(
|
||||
game=client.game,
|
||||
board_state=board_state,
|
||||
power_name=power_name,
|
||||
possible_orders=possible_orders,
|
||||
game_history=game_history,
|
||||
game_phase=client.get_current_short_phase(),
|
||||
log_file_path=log_file_path,
|
||||
active_powers=active_powers,
|
||||
agent_goals=agent.goals,
|
||||
agent_relationships=agent.relationships,
|
||||
agent_private_diary_str=agent.format_private_diary_for_prompt(),
|
||||
)
|
||||
|
||||
if not messages:
|
||||
logger.debug(f"No messages generated for {power_name} in round {round_number}")
|
||||
return False
|
||||
|
||||
# Process and send each message
|
||||
messages_sent = 0
|
||||
for message in messages:
|
||||
success = await _send_negotiation_message(client, agent, game_history, message, power_name)
|
||||
if success:
|
||||
messages_sent += 1
|
||||
|
||||
logger.info(f"Sent {messages_sent}/{len(messages)} messages for {power_name}")
|
||||
return messages_sent > 0
|
||||
|
||||
|
||||
async def _send_negotiation_message(
|
||||
client: WebSocketDiplomacyClient,
|
||||
agent: DiplomacyAgent,
|
||||
game_history: GameHistory,
|
||||
message: Dict,
|
||||
power_name: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Send a single negotiation message via WebSocket.
|
||||
|
||||
Args:
|
||||
client: WebSocket diplomacy client
|
||||
agent: The bot's AI agent
|
||||
game_history: Game history tracker
|
||||
message: Message dictionary with content and metadata
|
||||
power_name: Name of the sending power
|
||||
|
||||
Returns:
|
||||
True if message was sent successfully, False otherwise
|
||||
"""
|
||||
# Validate message structure
|
||||
if not isinstance(message, dict) or "content" not in message:
|
||||
logger.warning(f"Invalid message format from {power_name}: {message}")
|
||||
return False
|
||||
|
||||
content = message.get("content", "").strip()
|
||||
if not content:
|
||||
logger.debug(f"Empty message content from {power_name}, skipping")
|
||||
return False
|
||||
|
||||
# Determine recipient
|
||||
recipient = GLOBAL # Default to global
|
||||
if message.get("message_type") == "private":
|
||||
recipient = message.get("recipient", GLOBAL)
|
||||
# Validate recipient is a valid power
|
||||
if recipient not in client.powers and recipient != GLOBAL:
|
||||
logger.warning(f"Invalid recipient '{recipient}' from {power_name}, sending globally")
|
||||
recipient = GLOBAL
|
||||
|
||||
# Send the message via WebSocket
|
||||
await client.send_message(
|
||||
sender=power_name,
|
||||
recipient=recipient,
|
||||
message=content,
|
||||
phase=client.get_current_short_phase(),
|
||||
)
|
||||
|
||||
# Add to game history
|
||||
game_history.add_message(
|
||||
phase_name=client.get_current_short_phase(),
|
||||
sender=power_name,
|
||||
recipient=recipient,
|
||||
message_content=content,
|
||||
)
|
||||
|
||||
# Add to agent's journal
|
||||
journal_recipient = f"to {recipient}" if recipient != GLOBAL else "globally"
|
||||
agent.add_journal_entry(f"Sent message {journal_recipient} in {client.get_current_short_phase()}: {content[:100]}...")
|
||||
|
||||
logger.info(f"[{power_name} -> {recipient}] {content[:100]}...")
|
||||
return True
|
||||
|
||||
|
||||
async def should_participate_in_negotiations(
|
||||
client: "WebSocketDiplomacyClient",
|
||||
agent: "DiplomacyAgent",
|
||||
) -> bool:
|
||||
"""
|
||||
Determine if this bot should participate in negotiations.
|
||||
|
||||
Args:
|
||||
client: WebSocket diplomacy client
|
||||
agent: The bot's AI agent
|
||||
|
||||
Returns:
|
||||
True if the bot should participate in negotiations
|
||||
"""
|
||||
power_name = agent.power_name
|
||||
|
||||
# Don't negotiate if eliminated
|
||||
if client.get_power(power_name).is_eliminated():
|
||||
return False
|
||||
|
||||
# Don't negotiate if no orderable locations
|
||||
possible_orders = gather_possible_orders(client.game, power_name)
|
||||
if not possible_orders:
|
||||
return False
|
||||
|
||||
# Only negotiate during movement phases
|
||||
current_phase = client.get_current_short_phase()
|
||||
if not current_phase.endswith("M"):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_negotiation_delay(round_number: int, total_rounds: int) -> float:
|
||||
"""
|
||||
Calculate delay between negotiation rounds to allow message processing.
|
||||
|
||||
Args:
|
||||
round_number: Current round number (1-indexed)
|
||||
total_rounds: Total number of rounds
|
||||
|
||||
Returns:
|
||||
Delay in seconds
|
||||
"""
|
||||
# Longer delay in early rounds to allow more strategic messaging
|
||||
base_delay = 10.0 # Base delay between rounds
|
||||
|
||||
if round_number == 1:
|
||||
return base_delay * 1.5 # Extra time for first round
|
||||
elif round_number == total_rounds:
|
||||
return base_delay * 0.5 # Less time for final round
|
||||
else:
|
||||
return base_delay
|
||||
|
||||
|
||||
async def analyze_recent_messages_for_targeting(
|
||||
client: "WebSocketDiplomacyClient",
|
||||
power_name: str,
|
||||
max_messages: int = 20,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Analyze recent messages to identify which powers should be prioritized for negotiations.
|
||||
|
||||
Args:
|
||||
client: WebSocket diplomacy client
|
||||
power_name: Name of the analyzing power
|
||||
max_messages: Maximum number of recent messages to analyze
|
||||
|
||||
Returns:
|
||||
List of power names in order of priority for messaging
|
||||
"""
|
||||
# Get recent messages from current phase
|
||||
recent_messages = await client.get_recent_messages(limit=max_messages)
|
||||
|
||||
# Track who has been active and who has messaged us
|
||||
message_activity = {}
|
||||
direct_messages_to_us = {}
|
||||
|
||||
for message in recent_messages:
|
||||
sender = message.sender
|
||||
recipient = message.recipient
|
||||
|
||||
# Track general activity
|
||||
if sender != power_name:
|
||||
message_activity[sender] = message_activity.get(sender, 0) + 1
|
||||
|
||||
# Track direct messages to us
|
||||
if recipient == power_name and sender != power_name:
|
||||
direct_messages_to_us[sender] = direct_messages_to_us.get(sender, 0) + 1
|
||||
|
||||
# Get all active powers
|
||||
active_powers = [p_name for p_name, p_obj in client.powers.items() if not p_obj.is_eliminated() and p_name != power_name]
|
||||
|
||||
# Prioritize based on: 1) Powers that messaged us directly, 2) Most active powers
|
||||
priority_list = []
|
||||
|
||||
# First, add powers that sent us direct messages (sorted by count)
|
||||
direct_senders = sorted(direct_messages_to_us.items(), key=lambda x: x[1], reverse=True)
|
||||
for sender, _ in direct_senders:
|
||||
if sender in active_powers:
|
||||
priority_list.append(sender)
|
||||
|
||||
# Then add other active powers (sorted by activity)
|
||||
remaining_powers = [p for p in active_powers if p not in priority_list]
|
||||
activity_sorted = sorted(remaining_powers, key=lambda p: message_activity.get(p, 0), reverse=True)
|
||||
priority_list.extend(activity_sorted)
|
||||
|
||||
logger.debug(f"Message targeting priority for {power_name}: {priority_list}")
|
||||
return priority_list
|
||||
|
||||
|
||||
async def conduct_strategic_negotiation_round(
|
||||
client: "WebSocketDiplomacyClient",
|
||||
agent: "DiplomacyAgent",
|
||||
game_history: "GameHistory",
|
||||
model_error_stats: Dict[str, Dict[str, int]],
|
||||
log_file_path: str,
|
||||
round_number: int,
|
||||
max_rounds: int = 3,
|
||||
) -> bool:
|
||||
"""
|
||||
Conduct a single negotiation round with strategic message targeting.
|
||||
|
||||
This function analyzes recent message activity to determine which powers
|
||||
to prioritize for messaging in this round.
|
||||
|
||||
Args:
|
||||
client: WebSocket diplomacy client
|
||||
agent: The bot's AI agent
|
||||
game_history: Game history tracker
|
||||
model_error_stats: Error statistics tracking
|
||||
log_file_path: Path for logging
|
||||
round_number: Current round number (1-indexed)
|
||||
max_rounds: Maximum number of negotiation rounds
|
||||
|
||||
Returns:
|
||||
True if messages were sent successfully, False otherwise
|
||||
"""
|
||||
power_name = agent.power_name
|
||||
|
||||
# Analyze recent messages to prioritize targets
|
||||
priority_targets = await analyze_recent_messages_for_targeting(client, power_name)
|
||||
|
||||
# Limit to top 3-4 targets in later rounds to focus conversations
|
||||
if round_number > 1:
|
||||
priority_targets = priority_targets[: min(4, len(priority_targets))]
|
||||
|
||||
logger.info(f"Round {round_number} targets for {power_name}: {priority_targets}")
|
||||
|
||||
# Conduct negotiation with prioritized targets
|
||||
return await conduct_single_bot_negotiation(
|
||||
client=client,
|
||||
agent=agent,
|
||||
game_history=game_history,
|
||||
model_error_stats=model_error_stats,
|
||||
log_file_path=log_file_path,
|
||||
max_rounds=max_rounds,
|
||||
round_number=round_number,
|
||||
prioritize_targets=priority_targets,
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue