mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-19 12:58:09 +00:00
WIP: Exposing errors that cause bot to fail
Working to remove all the lower level try except statements that catch barely, and don't offer any information to top level about what went wrong. Better to throw exceptions and let the consumers catch if needed.
This commit is contained in:
parent
f37c344213
commit
c09d5bfe7e
13 changed files with 1013 additions and 400 deletions
3
.vscode/launch.json
vendored
3
.vscode/launch.json
vendored
|
|
@ -105,7 +105,7 @@
|
||||||
"--game-id",
|
"--game-id",
|
||||||
"${input:gameId}",
|
"${input:gameId}",
|
||||||
"--log-level",
|
"--log-level",
|
||||||
"INFO"
|
"DEBUG"
|
||||||
],
|
],
|
||||||
"console": "integratedTerminal",
|
"console": "integratedTerminal",
|
||||||
"justMyCode": true,
|
"justMyCode": true,
|
||||||
|
|
@ -138,6 +138,7 @@
|
||||||
"type": "pickString",
|
"type": "pickString",
|
||||||
"description": "Select the AI model to use",
|
"description": "Select the AI model to use",
|
||||||
"options": [
|
"options": [
|
||||||
|
"gemini-2.5-flash-lite-preview-06-17",
|
||||||
"gpt-3.5-turbo",
|
"gpt-3.5-turbo",
|
||||||
"gpt-4",
|
"gpt-4",
|
||||||
"claude-3-haiku",
|
"claude-3-haiku",
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ import re
|
||||||
import json_repair
|
import json_repair
|
||||||
import json5
|
import json5
|
||||||
|
|
||||||
|
from ai_diplomacy.game_history import GameHistory
|
||||||
from diplomacy.engine.game import Game # More forgiving JSON parser
|
from diplomacy.engine.game import Game # More forgiving JSON parser
|
||||||
|
|
||||||
# Assuming BaseModelClient is importable from clients.py in the same directory
|
# Assuming BaseModelClient is importable from clients.py in the same directory
|
||||||
|
|
@ -1150,6 +1151,9 @@ class DiplomacyAgent:
|
||||||
game_history: "GameHistory",
|
game_history: "GameHistory",
|
||||||
log_file_path: str | Path,
|
log_file_path: str | Path,
|
||||||
):
|
):
|
||||||
|
assert GameHistory is not None, (
|
||||||
|
"GameHistory object must be initialized before calling analyze_phase_and_update_state"
|
||||||
|
)
|
||||||
"""Analyzes the outcome of the last phase and updates goals/relationships using the LLM."""
|
"""Analyzes the outcome of the last phase and updates goals/relationships using the LLM."""
|
||||||
# Use self.power_name internally
|
# Use self.power_name internally
|
||||||
power_name = self.power_name
|
power_name = self.power_name
|
||||||
|
|
@ -1169,7 +1173,7 @@ class DiplomacyAgent:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Get previous phase safely from history
|
# Get previous phase safely from history
|
||||||
if not game_history or not game_history.phases:
|
if not game_history.phases:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"[{power_name}] No game history available to analyze for {game.current_short_phase}. Skipping state update."
|
f"[{power_name}] No game history available to analyze for {game.current_short_phase}. Skipping state update."
|
||||||
)
|
)
|
||||||
|
|
@ -1254,9 +1258,8 @@ class DiplomacyAgent:
|
||||||
log_entry_success = "FALSE" # Default
|
log_entry_success = "FALSE" # Default
|
||||||
update_data = None # Initialize
|
update_data = None # Initialize
|
||||||
|
|
||||||
if (
|
# Check if response is not None and not just whitespace
|
||||||
response is not None and response.strip()
|
if response is not None and response.strip():
|
||||||
): # Check if response is not None and not just whitespace
|
|
||||||
try:
|
try:
|
||||||
update_data = self._extract_json_from_text(response)
|
update_data = self._extract_json_from_text(response)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,6 @@ from anthropic import AsyncAnthropic
|
||||||
|
|
||||||
import google.generativeai as genai
|
import google.generativeai as genai
|
||||||
|
|
||||||
from diplomacy.engine.message import GLOBAL
|
|
||||||
from .game_history import GameHistory
|
from .game_history import GameHistory
|
||||||
from .utils import (
|
from .utils import (
|
||||||
load_prompt,
|
load_prompt,
|
||||||
|
|
@ -25,7 +24,6 @@ from .utils import (
|
||||||
|
|
||||||
# Import DiplomacyAgent for type hinting if needed, but avoid circular import if possible
|
# Import DiplomacyAgent for type hinting if needed, but avoid circular import if possible
|
||||||
# from .agent import DiplomacyAgent
|
# from .agent import DiplomacyAgent
|
||||||
from .possible_order_context import generate_rich_order_context
|
|
||||||
from .prompt_constructor import (
|
from .prompt_constructor import (
|
||||||
construct_order_generation_prompt,
|
construct_order_generation_prompt,
|
||||||
build_context_prompt,
|
build_context_prompt,
|
||||||
|
|
@ -113,7 +111,6 @@ class BaseModelClient:
|
||||||
raw_response = await run_llm_and_log(
|
raw_response = await run_llm_and_log(
|
||||||
client=self,
|
client=self,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
log_file_path=log_file_path,
|
|
||||||
power_name=power_name,
|
power_name=power_name,
|
||||||
phase=phase,
|
phase=phase,
|
||||||
response_type="order", # Context for run_llm_and_log's own error logging
|
response_type="order", # Context for run_llm_and_log's own error logging
|
||||||
|
|
@ -553,7 +550,6 @@ class BaseModelClient:
|
||||||
raw_response = await run_llm_and_log(
|
raw_response = await run_llm_and_log(
|
||||||
client=self,
|
client=self,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
log_file_path=log_file_path,
|
|
||||||
power_name=power_name,
|
power_name=power_name,
|
||||||
phase=game_phase, # Use game_phase for logging
|
phase=game_phase, # Use game_phase for logging
|
||||||
response_type="plan_reply", # Changed from 'plan' to avoid confusion
|
response_type="plan_reply", # Changed from 'plan' to avoid confusion
|
||||||
|
|
@ -604,7 +600,6 @@ class BaseModelClient:
|
||||||
raw_response = await run_llm_and_log(
|
raw_response = await run_llm_and_log(
|
||||||
client=self,
|
client=self,
|
||||||
prompt=raw_input_prompt,
|
prompt=raw_input_prompt,
|
||||||
log_file_path=log_file_path,
|
|
||||||
power_name=power_name,
|
power_name=power_name,
|
||||||
phase=game_phase,
|
phase=game_phase,
|
||||||
response_type="negotiation", # For run_llm_and_log's internal context
|
response_type="negotiation", # For run_llm_and_log's internal context
|
||||||
|
|
@ -829,7 +824,6 @@ class BaseModelClient:
|
||||||
raw_plan_response = await run_llm_and_log(
|
raw_plan_response = await run_llm_and_log(
|
||||||
client=self, # Pass self (the client instance)
|
client=self, # Pass self (the client instance)
|
||||||
prompt=full_prompt,
|
prompt=full_prompt,
|
||||||
log_file_path=log_file_path,
|
|
||||||
power_name=power_name,
|
power_name=power_name,
|
||||||
phase=game.current_short_phase,
|
phase=game.current_short_phase,
|
||||||
response_type="plan_generation", # More specific type for run_llm_and_log context
|
response_type="plan_generation", # More specific type for run_llm_and_log context
|
||||||
|
|
@ -1328,4 +1322,3 @@ def get_visible_messages_for_power(conversation_messages, power_name):
|
||||||
):
|
):
|
||||||
visible.append(msg)
|
visible.append(msg)
|
||||||
return visible # already in chronological order if appended that way
|
return visible # already in chronological order if appended that way
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ import logging
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
from .utils import run_llm_and_log, log_llm_response, load_prompt
|
from .utils import run_llm_and_log, log_llm_response
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from diplomacy import Game
|
from diplomacy import Game
|
||||||
|
|
@ -12,6 +12,25 @@ if TYPE_CHECKING:
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> str | None:
|
||||||
|
"""A local copy of the helper from agent.py to avoid circular imports."""
|
||||||
|
import os
|
||||||
|
|
||||||
|
try:
|
||||||
|
if prompts_dir:
|
||||||
|
filepath = os.path.join(prompts_dir, filename)
|
||||||
|
else:
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
default_prompts_dir = os.path.join(current_dir, "prompts")
|
||||||
|
filepath = os.path.join(default_prompts_dir, filename)
|
||||||
|
|
||||||
|
with open(filepath, "r", encoding="utf-8") as f:
|
||||||
|
return f.read()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error loading prompt file {filepath}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def run_diary_consolidation(
|
async def run_diary_consolidation(
|
||||||
agent: "DiplomacyAgent",
|
agent: "DiplomacyAgent",
|
||||||
game: "Game",
|
game: "Game",
|
||||||
|
|
@ -23,19 +42,32 @@ async def run_diary_consolidation(
|
||||||
Consolidate older diary entries while keeping recent ones.
|
Consolidate older diary entries while keeping recent ones.
|
||||||
This is the logic moved from the DiplomacyAgent class.
|
This is the logic moved from the DiplomacyAgent class.
|
||||||
"""
|
"""
|
||||||
logger.info(f"[{agent.power_name}] CONSOLIDATION START — {len(agent.full_private_diary)} total full entries")
|
logger.info(
|
||||||
|
f"[{agent.power_name}] CONSOLIDATION START — "
|
||||||
|
f"{len(agent.full_private_diary)} total full entries"
|
||||||
|
)
|
||||||
|
|
||||||
full_entries = [e for e in agent.full_private_diary if not e.startswith("[CONSOLIDATED HISTORY]")]
|
full_entries = [
|
||||||
|
e
|
||||||
|
for e in agent.full_private_diary
|
||||||
|
if not e.startswith("[CONSOLIDATED HISTORY]")
|
||||||
|
]
|
||||||
|
|
||||||
if len(full_entries) <= entries_to_keep_unsummarized:
|
if len(full_entries) <= entries_to_keep_unsummarized:
|
||||||
agent.private_diary = list(agent.full_private_diary)
|
agent.private_diary = list(agent.full_private_diary)
|
||||||
logger.info(f"[{agent.power_name}] ≤ {entries_to_keep_unsummarized} full entries — skipping consolidation")
|
logger.info(
|
||||||
|
f"[{agent.power_name}] ≤ {entries_to_keep_unsummarized} full entries — "
|
||||||
|
"skipping consolidation"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
boundary_entry = full_entries[-entries_to_keep_unsummarized]
|
boundary_entry = full_entries[-entries_to_keep_unsummarized]
|
||||||
match = re.search(r"\[[SFWRAB]\s*(\d{4})", boundary_entry)
|
match = re.search(r"\[[SFWRAB]\s*(\d{4})", boundary_entry)
|
||||||
if not match:
|
if not match:
|
||||||
logger.error(f"[{agent.power_name}] Could not parse year from boundary entry; aborting consolidation")
|
logger.error(
|
||||||
|
f"[{agent.power_name}] Could not parse year from boundary entry; "
|
||||||
|
"aborting consolidation"
|
||||||
|
)
|
||||||
agent.private_diary = list(agent.full_private_diary)
|
agent.private_diary = list(agent.full_private_diary)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -46,19 +78,37 @@ async def run_diary_consolidation(
|
||||||
m = re.search(r"\[[SFWRAB]\s*(\d{4})", entry)
|
m = re.search(r"\[[SFWRAB]\s*(\d{4})", entry)
|
||||||
return int(m.group(1)) if m else None
|
return int(m.group(1)) if m else None
|
||||||
|
|
||||||
entries_to_summarize = [e for e in full_entries if (_entry_year(e) is not None and _entry_year(e) < cutoff_year)]
|
entries_to_summarize = [
|
||||||
entries_to_keep = [e for e in full_entries if (_entry_year(e) is None or _entry_year(e) >= cutoff_year)]
|
e
|
||||||
|
for e in full_entries
|
||||||
|
if (_entry_year(e) is not None and _entry_year(e) < cutoff_year)
|
||||||
|
]
|
||||||
|
entries_to_keep = [
|
||||||
|
e
|
||||||
|
for e in full_entries
|
||||||
|
if (_entry_year(e) is None or _entry_year(e) >= cutoff_year)
|
||||||
|
]
|
||||||
|
|
||||||
logger.info(f"[{agent.power_name}] Summarising {len(entries_to_summarize)} entries; keeping {len(entries_to_keep)} recent entries verbatim")
|
logger.info(
|
||||||
|
f"[{agent.power_name}] Summarising {len(entries_to_summarize)} entries; "
|
||||||
|
f"keeping {len(entries_to_keep)} recent entries verbatim"
|
||||||
|
)
|
||||||
|
|
||||||
if not entries_to_summarize:
|
if not entries_to_summarize:
|
||||||
agent.private_diary = list(agent.full_private_diary)
|
agent.private_diary = list(agent.full_private_diary)
|
||||||
logger.warning(f"[{agent.power_name}] No eligible entries to summarise; context diary left unchanged")
|
logger.warning(
|
||||||
|
f"[{agent.power_name}] No eligible entries to summarise; "
|
||||||
|
"context diary left unchanged"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
prompt_template = load_prompt("diary_consolidation_prompt.txt", prompts_dir=prompts_dir)
|
prompt_template = _load_prompt_file(
|
||||||
|
"diary_consolidation_prompt.txt", prompts_dir=prompts_dir
|
||||||
|
)
|
||||||
if not prompt_template:
|
if not prompt_template:
|
||||||
logger.error(f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting")
|
logger.error(
|
||||||
|
f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
prompt = prompt_template.format(
|
prompt = prompt_template.format(
|
||||||
|
|
@ -87,14 +137,23 @@ async def run_diary_consolidation(
|
||||||
new_summary_entry = f"[CONSOLIDATED HISTORY] {consolidated_text}"
|
new_summary_entry = f"[CONSOLIDATED HISTORY] {consolidated_text}"
|
||||||
agent.private_diary = [new_summary_entry] + entries_to_keep
|
agent.private_diary = [new_summary_entry] + entries_to_keep
|
||||||
success_flag = "TRUE"
|
success_flag = "TRUE"
|
||||||
logger.info(f"[{agent.power_name}] Consolidation complete — {len(agent.private_diary)} context entries now")
|
logger.info(
|
||||||
|
f"[{agent.power_name}] Consolidation complete — "
|
||||||
|
f"{len(agent.private_diary)} context entries now"
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error(f"[{agent.power_name}] Diary consolidation failed: {exc}", exc_info=True)
|
logger.error(
|
||||||
|
f"[{agent.power_name}] Diary consolidation failed: {exc}", exc_info=True
|
||||||
|
)
|
||||||
finally:
|
finally:
|
||||||
log_llm_response(
|
log_llm_response(
|
||||||
log_file_path=log_file_path,
|
log_file_path=log_file_path,
|
||||||
model_name=(consolidation_client.model_name if consolidation_client is not None else agent.client.model_name),
|
model_name=(
|
||||||
|
consolidation_client.model_name
|
||||||
|
if consolidation_client is not None
|
||||||
|
else agent.client.model_name
|
||||||
|
),
|
||||||
power_name=agent.power_name,
|
power_name=agent.power_name,
|
||||||
phase=game.current_short_phase,
|
phase=game.current_short_phase,
|
||||||
response_type="diary_consolidation",
|
response_type="diary_consolidation",
|
||||||
|
|
@ -102,3 +161,4 @@ async def run_diary_consolidation(
|
||||||
raw_response=raw_response,
|
raw_response=raw_response,
|
||||||
success=success_flag,
|
success=success_flag,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -148,7 +148,6 @@ class GameHistory:
|
||||||
return {}
|
return {}
|
||||||
return self.phases[-1].plans
|
return self.phases[-1].plans
|
||||||
|
|
||||||
# NEW METHOD
|
|
||||||
def get_messages_this_round(self, power_name: str, current_phase_name: str) -> str:
|
def get_messages_this_round(self, power_name: str, current_phase_name: str) -> str:
|
||||||
current_phase: Optional[Phase] = None
|
current_phase: Optional[Phase] = None
|
||||||
for phase_obj in self.phases:
|
for phase_obj in self.phases:
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from os.path import exists
|
import datetime
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import warnings
|
import warnings
|
||||||
|
|
@ -6,16 +6,16 @@ import warnings
|
||||||
|
|
||||||
class Configuration(BaseSettings):
|
class Configuration(BaseSettings):
|
||||||
DEBUG: bool = False
|
DEBUG: bool = False
|
||||||
log_file_path: Path = Path("./logs/logs.txt")
|
log_file_path: Path
|
||||||
DEEPSEEK_API_KEY: str | None = None
|
DEEPSEEK_API_KEY: str | None = None
|
||||||
OPENAI_API_KEY: str | None = None
|
OPENAI_API_KEY: str | None = None
|
||||||
ANTHROPIC_API_KEY: str | None = None
|
ANTHROPIC_API_KEY: str | None = None
|
||||||
GEMINI_API_KEY: str | None = None
|
GEMINI_API_KEY: str | None = None
|
||||||
OPENROUTER_API_KEY: str | None = None
|
OPENROUTER_API_KEY: str | None = None
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, power_name, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
self.log_file_path = Path(f"./logs/{datetime.datetime.now().strftime('%d-%m-%y_%H:%M')}/{power_name}.txt")
|
||||||
# Make the path absolute, gets rid of weirdness of calling this in different places
|
# Make the path absolute, gets rid of weirdness of calling this in different places
|
||||||
self.log_file_path = self.log_file_path.resolve()
|
self.log_file_path = self.log_file_path.resolve()
|
||||||
self.log_file_path.parent.mkdir(parents=True, exist_ok=True)
|
self.log_file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
@ -42,11 +42,7 @@ class Configuration(BaseSettings):
|
||||||
"""Override to check for empty API keys at access time"""
|
"""Override to check for empty API keys at access time"""
|
||||||
value = super().__getattribute__(name)
|
value = super().__getattribute__(name)
|
||||||
|
|
||||||
if name.endswith("_KEY") and (
|
if name.endswith("_KEY") and (not value or (isinstance(value, str) and len(value) == 0)):
|
||||||
not value or (isinstance(value, str) and len(value) == 0)
|
raise ValueError(f"API key '{name}' is not set or is empty. Please configure it before use.")
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
f"API key '{name}' is not set or is empty. Please configure it before use."
|
|
||||||
)
|
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
|
||||||
|
|
@ -30,12 +30,9 @@ from websocket_diplomacy_client import (
|
||||||
from models import (
|
from models import (
|
||||||
SignInRequest,
|
SignInRequest,
|
||||||
CreateGameRequest,
|
CreateGameRequest,
|
||||||
JoinGameRequest,
|
|
||||||
SetOrdersRequest,
|
SetOrdersRequest,
|
||||||
ListGamesRequest,
|
ListGamesRequest,
|
||||||
ProcessGameRequest,
|
|
||||||
serialize_message,
|
serialize_message,
|
||||||
parse_message,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
|
@ -64,9 +61,7 @@ logging.getLogger("root").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
def parse_arguments():
|
def parse_arguments():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(description="Run a Diplomacy game simulation with WebSocket server connection.")
|
||||||
description="Run a Diplomacy game simulation with WebSocket server connection."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--hostname",
|
"--hostname",
|
||||||
type=str,
|
type=str,
|
||||||
|
|
@ -137,9 +132,7 @@ def parse_arguments():
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
async def join_powers_for_testing(
|
async def join_powers_for_testing(client: WebSocketDiplomacyClient, power_model_map: dict):
|
||||||
client: WebSocketDiplomacyClient, power_model_map: dict
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Join multiple powers in the same game for testing purposes.
|
Join multiple powers in the same game for testing purposes.
|
||||||
This simulates having multiple AI players in one game.
|
This simulates having multiple AI players in one game.
|
||||||
|
|
@ -150,17 +143,13 @@ async def join_powers_for_testing(
|
||||||
for power_name in power_names[1:]:
|
for power_name in power_names[1:]:
|
||||||
try:
|
try:
|
||||||
logger.info(f"Attempting to join power {power_name}")
|
logger.info(f"Attempting to join power {power_name}")
|
||||||
await client.channel.join_game(
|
await client.channel.join_game(game_id=client.game_id, power_name=power_name)
|
||||||
game_id=client.game_id, power_name=power_name
|
|
||||||
)
|
|
||||||
logger.info(f"Successfully joined {power_name}")
|
logger.info(f"Successfully joined {power_name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Could not join {power_name}: {e}")
|
logger.warning(f"Could not join {power_name}: {e}")
|
||||||
|
|
||||||
|
|
||||||
async def create_or_join_game(
|
async def create_or_join_game(client: WebSocketDiplomacyClient, args, power_model_map: dict):
|
||||||
client: WebSocketDiplomacyClient, args, power_model_map: dict
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Create a new game or join an existing one based on arguments.
|
Create a new game or join an existing one based on arguments.
|
||||||
"""
|
"""
|
||||||
|
|
@ -171,9 +160,7 @@ async def create_or_join_game(
|
||||||
# List available games first to see what's available
|
# List available games first to see what's available
|
||||||
try:
|
try:
|
||||||
games = await client.list_games()
|
games = await client.list_games()
|
||||||
logger.info(
|
logger.info(f"Available games: {[g.get('game_id', 'unknown') for g in games]}")
|
||||||
f"Available games: {[g.get('game_id', 'unknown') for g in games]}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Could not list games: {e}")
|
logger.warning(f"Could not list games: {e}")
|
||||||
|
|
||||||
|
|
@ -189,19 +176,13 @@ async def create_or_join_game(
|
||||||
logger.info("Creating new game")
|
logger.info("Creating new game")
|
||||||
|
|
||||||
# Get the first power to control
|
# Get the first power to control
|
||||||
first_power = (
|
first_power = list(power_model_map.keys())[0] if not args.create_multi_power_game else None
|
||||||
list(power_model_map.keys())[0]
|
|
||||||
if not args.create_multi_power_game
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
game = await client.create_game(
|
game = await client.create_game(
|
||||||
map_name="standard",
|
map_name="standard",
|
||||||
rules=["NO_PRESS", "IGNORE_ERRORS", "POWER_CHOICE"],
|
rules=["NO_PRESS", "IGNORE_ERRORS", "POWER_CHOICE"],
|
||||||
power_name=first_power,
|
power_name=first_power,
|
||||||
n_controls=7
|
n_controls=7 if not args.create_multi_power_game else 1, # Lower requirement for testing
|
||||||
if not args.create_multi_power_game
|
|
||||||
else 1, # Lower requirement for testing
|
|
||||||
deadline=None, # No time pressure for AI testing
|
deadline=None, # No time pressure for AI testing
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -223,11 +204,7 @@ async def demonstrate_typed_messages(client: WebSocketDiplomacyClient, args):
|
||||||
logger.info("=== Typed Message Examples ===")
|
logger.info("=== Typed Message Examples ===")
|
||||||
|
|
||||||
# Example 1: Authentication Request
|
# Example 1: Authentication Request
|
||||||
sign_in_request = SignInRequest(
|
sign_in_request = SignInRequest(request_id=str(uuid.uuid4()), username=args.username, password=args.password)
|
||||||
request_id=str(uuid.uuid4()),
|
|
||||||
username=args.username,
|
|
||||||
password=args.password
|
|
||||||
)
|
|
||||||
logger.info(f"SignIn request: {serialize_message(sign_in_request)}")
|
logger.info(f"SignIn request: {serialize_message(sign_in_request)}")
|
||||||
|
|
||||||
# Example 2: Create Game Request
|
# Example 2: Create Game Request
|
||||||
|
|
@ -237,26 +214,22 @@ async def demonstrate_typed_messages(client: WebSocketDiplomacyClient, args):
|
||||||
map_name="standard",
|
map_name="standard",
|
||||||
rules=["NO_PRESS", "IGNORE_ERRORS", "POWER_CHOICE"],
|
rules=["NO_PRESS", "IGNORE_ERRORS", "POWER_CHOICE"],
|
||||||
n_controls=1,
|
n_controls=1,
|
||||||
power_name="FRANCE"
|
power_name="FRANCE",
|
||||||
)
|
)
|
||||||
logger.info(f"CreateGame request: {serialize_message(create_game_request)}")
|
logger.info(f"CreateGame request: {serialize_message(create_game_request)}")
|
||||||
|
|
||||||
# Example 3: List Games Request
|
# Example 3: List Games Request
|
||||||
list_games_request = ListGamesRequest(
|
list_games_request = ListGamesRequest(request_id=str(uuid.uuid4()), token="example_auth_token_here", include_protected=False)
|
||||||
request_id=str(uuid.uuid4()),
|
|
||||||
token="example_auth_token_here",
|
|
||||||
include_protected=False
|
|
||||||
)
|
|
||||||
logger.info(f"ListGames request: {serialize_message(list_games_request)}")
|
logger.info(f"ListGames request: {serialize_message(list_games_request)}")
|
||||||
|
|
||||||
# Example 4: Set Orders Request (would be used in game loop)
|
# Example 4: Set Orders Request (would be used in game loop)
|
||||||
if hasattr(client, 'game_id') and hasattr(client, 'token'):
|
if hasattr(client, "game_id") and hasattr(client, "token"):
|
||||||
set_orders_request = SetOrdersRequest(
|
set_orders_request = SetOrdersRequest(
|
||||||
request_id=str(uuid.uuid4()),
|
request_id=str(uuid.uuid4()),
|
||||||
token=client.token or "example_token",
|
token=client.token or "example_token",
|
||||||
game_id=client.game_id or "example_game_id",
|
game_id=client.game_id or "example_game_id",
|
||||||
game_role="FRANCE",
|
game_role="FRANCE",
|
||||||
orders=["A PAR H", "F BRE H", "A MAR H"]
|
orders=["A PAR H", "F BRE H", "A MAR H"],
|
||||||
)
|
)
|
||||||
logger.info(f"SetOrders request: {serialize_message(set_orders_request)}")
|
logger.info(f"SetOrders request: {serialize_message(set_orders_request)}")
|
||||||
|
|
||||||
|
|
@ -282,9 +255,7 @@ async def main():
|
||||||
logger.info("Starting WebSocket-based Diplomacy game with multiple LLMs")
|
logger.info("Starting WebSocket-based Diplomacy game with multiple LLMs")
|
||||||
start_whole = time.time()
|
start_whole = time.time()
|
||||||
|
|
||||||
model_error_stats = defaultdict(
|
model_error_stats = defaultdict(lambda: {"conversation_errors": 0, "order_decoding_errors": 0})
|
||||||
lambda: {"conversation_errors": 0, "order_decoding_errors": 0}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Determine the result folder based on a timestamp
|
# Determine the result folder based on a timestamp
|
||||||
timestamp_str = time.strftime("%Y%m%d_%H%M%S")
|
timestamp_str = time.strftime("%Y%m%d_%H%M%S")
|
||||||
|
|
@ -306,9 +277,7 @@ async def main():
|
||||||
|
|
||||||
# File paths
|
# File paths
|
||||||
manifesto_path = f"{result_folder}/game_manifesto.txt"
|
manifesto_path = f"{result_folder}/game_manifesto.txt"
|
||||||
game_file_path = (
|
game_file_path = args.output if args.output else f"{result_folder}/lmvsgame_websocket.json"
|
||||||
args.output if args.output else f"{result_folder}/lmvsgame_websocket.json"
|
|
||||||
)
|
|
||||||
overview_file_path = f"{result_folder}/overview.jsonl"
|
overview_file_path = f"{result_folder}/overview.jsonl"
|
||||||
llm_log_file_path = f"{result_folder}/llm_responses.csv"
|
llm_log_file_path = f"{result_folder}/llm_responses.csv"
|
||||||
|
|
||||||
|
|
@ -325,9 +294,7 @@ async def main():
|
||||||
]
|
]
|
||||||
provided_models = [name.strip() for name in args.models.split(",")]
|
provided_models = [name.strip() for name in args.models.split(",")]
|
||||||
if len(provided_models) != len(powers_order):
|
if len(provided_models) != len(powers_order):
|
||||||
logger.error(
|
logger.error(f"Expected {len(powers_order)} models for --models but got {len(provided_models)}. Exiting.")
|
||||||
f"Expected {len(powers_order)} models for --models but got {len(provided_models)}. Exiting."
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
power_model_map = dict(zip(powers_order, provided_models))
|
power_model_map = dict(zip(powers_order, provided_models))
|
||||||
else:
|
else:
|
||||||
|
|
@ -377,32 +344,20 @@ async def main():
|
||||||
client_obj = load_model_client(model_id)
|
client_obj = load_model_client(model_id)
|
||||||
agent = DiplomacyAgent(power_name=power_name, client=client_obj)
|
agent = DiplomacyAgent(power_name=power_name, client=client_obj)
|
||||||
agents[power_name] = agent
|
agents[power_name] = agent
|
||||||
logger.info(
|
logger.info(f"Preparing initialization task for {power_name} with model {model_id}")
|
||||||
f"Preparing initialization task for {power_name} with model {model_id}"
|
initialization_tasks.append(initialize_agent_state_ext(agent, client.game, game_history, llm_log_file_path))
|
||||||
)
|
|
||||||
initialization_tasks.append(
|
|
||||||
initialize_agent_state_ext(
|
|
||||||
agent, client.game, game_history, llm_log_file_path
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Failed to create agent for {power_name} with model {model_id}: {e}",
|
f"Failed to create agent for {power_name} with model {model_id}: {e}",
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(f"Skipping agent initialization for {power_name} (no model or eliminated)")
|
||||||
f"Skipping agent initialization for {power_name} (no model or eliminated)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Run initializations concurrently
|
# Run initializations concurrently
|
||||||
if initialization_tasks:
|
if initialization_tasks:
|
||||||
logger.info(
|
logger.info(f"Running {len(initialization_tasks)} agent initializations concurrently...")
|
||||||
f"Running {len(initialization_tasks)} agent initializations concurrently..."
|
initialization_results = await asyncio.gather(*initialization_tasks, return_exceptions=True)
|
||||||
)
|
|
||||||
initialization_results = await asyncio.gather(
|
|
||||||
*initialization_tasks, return_exceptions=True
|
|
||||||
)
|
|
||||||
|
|
||||||
initialized_powers = list(agents.keys())
|
initialized_powers = list(agents.keys())
|
||||||
for i, result in enumerate(initialization_results):
|
for i, result in enumerate(initialization_results):
|
||||||
|
|
@ -414,9 +369,7 @@ async def main():
|
||||||
exc_info=result,
|
exc_info=result,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(f"Successfully initialized agent state for {power_name}.")
|
||||||
f"Successfully initialized agent state for {power_name}."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Main game loop
|
# Main game loop
|
||||||
all_phase_relationships = {}
|
all_phase_relationships = {}
|
||||||
|
|
@ -433,9 +386,7 @@ async def main():
|
||||||
game_history.add_phase(current_phase)
|
game_history.add_phase(current_phase)
|
||||||
current_short_phase = client.get_current_short_phase()
|
current_short_phase = client.get_current_short_phase()
|
||||||
|
|
||||||
logger.info(
|
logger.info(f"PHASE: {current_phase} (time so far: {phase_start - start_whole:.2f}s)")
|
||||||
f"PHASE: {current_phase} (time so far: {phase_start - start_whole:.2f}s)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prevent unbounded simulation based on year
|
# Prevent unbounded simulation based on year
|
||||||
year_str = current_phase[1:5]
|
year_str = current_phase[1:5]
|
||||||
|
|
@ -447,9 +398,7 @@ async def main():
|
||||||
# Negotiations for movement phases
|
# Negotiations for movement phases
|
||||||
if client.get_current_short_phase().endswith("M"):
|
if client.get_current_short_phase().endswith("M"):
|
||||||
if args.num_negotiation_rounds > 0:
|
if args.num_negotiation_rounds > 0:
|
||||||
logger.info(
|
logger.info(f"Running {args.num_negotiation_rounds} rounds of negotiations...")
|
||||||
f"Running {args.num_negotiation_rounds} rounds of negotiations..."
|
|
||||||
)
|
|
||||||
game_history = await conduct_negotiations(
|
game_history = await conduct_negotiations(
|
||||||
client.game, # Pass the NetworkGame object
|
client.game, # Pass the NetworkGame object
|
||||||
agents,
|
agents,
|
||||||
|
|
@ -459,9 +408,7 @@ async def main():
|
||||||
log_file_path=llm_log_file_path,
|
log_file_path=llm_log_file_path,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info("Skipping negotiation phase as num_negotiation_rounds=0")
|
||||||
"Skipping negotiation phase as num_negotiation_rounds=0"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Planning phase (if enabled)
|
# Planning phase (if enabled)
|
||||||
if args.planning_phase:
|
if args.planning_phase:
|
||||||
|
|
@ -475,29 +422,19 @@ async def main():
|
||||||
)
|
)
|
||||||
|
|
||||||
# Generate negotiation diary entries
|
# Generate negotiation diary entries
|
||||||
logger.info(
|
logger.info(f"Generating negotiation diary entries for phase {current_short_phase}...")
|
||||||
f"Generating negotiation diary entries for phase {current_short_phase}..."
|
active_powers_for_neg_diary = [p for p in agents.keys() if not client.get_power(p).is_eliminated()]
|
||||||
)
|
|
||||||
active_powers_for_neg_diary = [
|
|
||||||
p for p in agents.keys() if not client.get_power(p).is_eliminated()
|
|
||||||
]
|
|
||||||
|
|
||||||
neg_diary_tasks = []
|
neg_diary_tasks = []
|
||||||
for power_name, agent in agents.items():
|
for power_name, agent in agents.items():
|
||||||
if not client.get_power(power_name).is_eliminated():
|
if not client.get_power(power_name).is_eliminated():
|
||||||
neg_diary_tasks.append(
|
neg_diary_tasks.append(agent.generate_negotiation_diary_entry(client.game, game_history, llm_log_file_path))
|
||||||
agent.generate_negotiation_diary_entry(
|
|
||||||
client.game, game_history, llm_log_file_path
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if neg_diary_tasks:
|
if neg_diary_tasks:
|
||||||
await asyncio.gather(*neg_diary_tasks, return_exceptions=True)
|
await asyncio.gather(*neg_diary_tasks, return_exceptions=True)
|
||||||
|
|
||||||
# AI Decision Making: Get orders for each controlled power
|
# AI Decision Making: Get orders for each controlled power
|
||||||
logger.info("Getting orders from agents...")
|
logger.info("Getting orders from agents...")
|
||||||
active_powers_for_orders = [
|
active_powers_for_orders = [p for p in agents.keys() if not client.get_power(p).is_eliminated()]
|
||||||
p for p in agents.keys() if not client.get_power(p).is_eliminated()
|
|
||||||
]
|
|
||||||
|
|
||||||
order_tasks = []
|
order_tasks = []
|
||||||
order_power_names = []
|
order_power_names = []
|
||||||
|
|
@ -505,35 +442,23 @@ async def main():
|
||||||
|
|
||||||
for power_name, agent in agents.items():
|
for power_name, agent in agents.items():
|
||||||
if client.get_power(power_name).is_eliminated():
|
if client.get_power(power_name).is_eliminated():
|
||||||
logger.debug(
|
logger.debug(f"Skipping order generation for eliminated power {power_name}.")
|
||||||
f"Skipping order generation for eliminated power {power_name}."
|
|
||||||
)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Diagnostic logging
|
# Diagnostic logging
|
||||||
logger.info(
|
logger.info(f"--- Diagnostic Log for {power_name} in phase {current_phase} ---")
|
||||||
f"--- Diagnostic Log for {power_name} in phase {current_phase} ---"
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
orderable_locs = client.get_orderable_locations(power_name)
|
orderable_locs = client.get_orderable_locations(power_name)
|
||||||
logger.info(
|
logger.info(f"[{power_name}][{current_phase}] Orderable locations: {orderable_locs}")
|
||||||
f"[{power_name}][{current_phase}] Orderable locations: {orderable_locs}"
|
|
||||||
)
|
|
||||||
actual_units = client.get_units(power_name)
|
actual_units = client.get_units(power_name)
|
||||||
logger.info(
|
logger.info(f"[{power_name}][{current_phase}] Actual units: {actual_units}")
|
||||||
f"[{power_name}][{current_phase}] Actual units: {actual_units}"
|
|
||||||
)
|
|
||||||
except Exception as e_diag:
|
except Exception as e_diag:
|
||||||
logger.error(
|
logger.error(f"[{power_name}][{current_phase}] Error during diagnostic logging: {e_diag}")
|
||||||
f"[{power_name}][{current_phase}] Error during diagnostic logging: {e_diag}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Calculate possible orders
|
# Calculate possible orders
|
||||||
possible_orders = gather_possible_orders(client.game, power_name)
|
possible_orders = gather_possible_orders(client.game, power_name)
|
||||||
if not possible_orders:
|
if not possible_orders:
|
||||||
logger.debug(
|
logger.debug(f"No orderable locations for {power_name}; submitting empty orders.")
|
||||||
f"No orderable locations for {power_name}; submitting empty orders."
|
|
||||||
)
|
|
||||||
await client.set_orders(power_name, [])
|
await client.set_orders(power_name, [])
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
@ -559,12 +484,8 @@ async def main():
|
||||||
|
|
||||||
# Run order generation concurrently
|
# Run order generation concurrently
|
||||||
if order_tasks:
|
if order_tasks:
|
||||||
logger.debug(
|
logger.debug(f"Running {len(order_tasks)} order generation tasks concurrently...")
|
||||||
f"Running {len(order_tasks)} order generation tasks concurrently..."
|
order_results = await asyncio.gather(*order_tasks, return_exceptions=True)
|
||||||
)
|
|
||||||
order_results = await asyncio.gather(
|
|
||||||
*order_tasks, return_exceptions=True
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
order_results = []
|
order_results = []
|
||||||
|
|
||||||
|
|
@ -580,9 +501,7 @@ async def main():
|
||||||
)
|
)
|
||||||
await client.set_orders(p_name, [])
|
await client.set_orders(p_name, [])
|
||||||
elif result is None:
|
elif result is None:
|
||||||
logger.warning(
|
logger.warning(f"get_valid_orders returned None for {p_name}. Setting empty orders.")
|
||||||
f"get_valid_orders returned None for {p_name}. Setting empty orders."
|
|
||||||
)
|
|
||||||
await client.set_orders(p_name, [])
|
await client.set_orders(p_name, [])
|
||||||
else:
|
else:
|
||||||
orders = result
|
orders = result
|
||||||
|
|
@ -600,15 +519,11 @@ async def main():
|
||||||
# await client.send_typed_message(set_orders_msg)
|
# await client.send_typed_message(set_orders_msg)
|
||||||
|
|
||||||
await client.set_orders(p_name, orders)
|
await client.set_orders(p_name, orders)
|
||||||
logger.debug(
|
logger.debug(f"Set orders for {p_name} in {current_short_phase}: {orders}")
|
||||||
f"Set orders for {p_name} in {current_short_phase}: {orders}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate order diary entry
|
# Generate order diary entry
|
||||||
try:
|
try:
|
||||||
await agent.generate_order_diary_entry(
|
await agent.generate_order_diary_entry(client.game, orders, llm_log_file_path)
|
||||||
client.game, orders, llm_log_file_path
|
|
||||||
)
|
|
||||||
except Exception as e_diary:
|
except Exception as e_diary:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Error generating order diary for {p_name}: {e_diary}",
|
f"Error generating order diary for {p_name}: {e_diary}",
|
||||||
|
|
@ -650,22 +565,13 @@ async def main():
|
||||||
# Collect relationships for this phase
|
# Collect relationships for this phase
|
||||||
current_relationships_for_phase = {}
|
current_relationships_for_phase = {}
|
||||||
for power_name, agent in agents.items():
|
for power_name, agent in agents.items():
|
||||||
if (
|
if power_name in client.powers and not client.get_power(power_name).is_eliminated():
|
||||||
power_name in client.powers
|
|
||||||
and not client.get_power(power_name).is_eliminated()
|
|
||||||
):
|
|
||||||
current_relationships_for_phase[power_name] = agent.relationships
|
current_relationships_for_phase[power_name] = agent.relationships
|
||||||
all_phase_relationships[current_short_phase] = (
|
all_phase_relationships[current_short_phase] = current_relationships_for_phase
|
||||||
current_relationships_for_phase
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate phase result diary entries
|
# Generate phase result diary entries
|
||||||
logger.info(
|
logger.info(f"Generating phase result diary entries for completed phase {current_phase}...")
|
||||||
f"Generating phase result diary entries for completed phase {current_phase}..."
|
phase_summary = getattr(client.game, "phase_summaries", {}).get(current_phase, "(Summary not generated)")
|
||||||
)
|
|
||||||
phase_summary = getattr(client.game, "phase_summaries", {}).get(
|
|
||||||
current_phase, "(Summary not generated)"
|
|
||||||
)
|
|
||||||
all_orders_this_phase = current_order_history
|
all_orders_this_phase = current_order_history
|
||||||
|
|
||||||
phase_result_diary_tasks = []
|
phase_result_diary_tasks = []
|
||||||
|
|
@ -685,16 +591,10 @@ async def main():
|
||||||
await asyncio.gather(*phase_result_diary_tasks, return_exceptions=True)
|
await asyncio.gather(*phase_result_diary_tasks, return_exceptions=True)
|
||||||
|
|
||||||
# State update analysis
|
# State update analysis
|
||||||
logger.info(
|
logger.info(f"Starting state update analysis for completed phase {current_phase}...")
|
||||||
f"Starting state update analysis for completed phase {current_phase}..."
|
|
||||||
)
|
|
||||||
current_board_state = client.get_state()
|
current_board_state = client.get_state()
|
||||||
|
|
||||||
active_agent_powers = [
|
active_agent_powers = [(p, power) for p, power in client.powers.items() if p in agents and not power.is_eliminated()]
|
||||||
(p, power)
|
|
||||||
for p, power in client.powers.items()
|
|
||||||
if p in agents and not power.is_eliminated()
|
|
||||||
]
|
|
||||||
|
|
||||||
if active_agent_powers:
|
if active_agent_powers:
|
||||||
state_update_tasks = []
|
state_update_tasks = []
|
||||||
|
|
@ -781,4 +681,3 @@ async def main():
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,15 +8,17 @@ and notifications (server -> client).
|
||||||
|
|
||||||
from typing import Optional, List, Dict, Any, Union, Literal
|
from typing import Optional, List, Dict, Any, Union, Literal
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Base Message Classes
|
# Base Message Classes
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
class BaseMessage(BaseModel, ABC):
|
class BaseMessage(BaseModel, ABC):
|
||||||
"""Base class for all WebSocket messages."""
|
"""Base class for all WebSocket messages."""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
|
@ -25,17 +27,20 @@ class BaseMessage(BaseModel, ABC):
|
||||||
|
|
||||||
class BaseRequest(BaseMessage):
|
class BaseRequest(BaseMessage):
|
||||||
"""Base class for all client -> server requests."""
|
"""Base class for all client -> server requests."""
|
||||||
|
|
||||||
request_id: str
|
request_id: str
|
||||||
re_sent: bool = False
|
re_sent: bool = False
|
||||||
|
|
||||||
|
|
||||||
class BaseResponse(BaseMessage):
|
class BaseResponse(BaseMessage):
|
||||||
"""Base class for all server -> client responses."""
|
"""Base class for all server -> client responses."""
|
||||||
|
|
||||||
request_id: str
|
request_id: str
|
||||||
|
|
||||||
|
|
||||||
class BaseNotification(BaseMessage):
|
class BaseNotification(BaseMessage):
|
||||||
"""Base class for all server -> client notifications (no request_id)."""
|
"""Base class for all server -> client notifications (no request_id)."""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -43,8 +48,10 @@ class BaseNotification(BaseMessage):
|
||||||
# Authentication & Connection Level Messages
|
# Authentication & Connection Level Messages
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
class SignInRequest(BaseRequest):
|
class SignInRequest(BaseRequest):
|
||||||
"""Client authentication request."""
|
"""Client authentication request."""
|
||||||
|
|
||||||
name: Literal["sign_in"] = "sign_in"
|
name: Literal["sign_in"] = "sign_in"
|
||||||
username: str
|
username: str
|
||||||
password: str
|
password: str
|
||||||
|
|
@ -52,6 +59,7 @@ class SignInRequest(BaseRequest):
|
||||||
|
|
||||||
class GetDaidePortRequest(BaseRequest):
|
class GetDaidePortRequest(BaseRequest):
|
||||||
"""Request DAIDE TCP port for a game."""
|
"""Request DAIDE TCP port for a game."""
|
||||||
|
|
||||||
name: Literal["get_daide_port"] = "get_daide_port"
|
name: Literal["get_daide_port"] = "get_daide_port"
|
||||||
game_id: str
|
game_id: str
|
||||||
|
|
||||||
|
|
@ -60,13 +68,16 @@ class GetDaidePortRequest(BaseRequest):
|
||||||
# Channel Level Messages (require authentication token)
|
# Channel Level Messages (require authentication token)
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ChannelRequest(BaseRequest):
|
class ChannelRequest(BaseRequest):
|
||||||
"""Base class for channel-level requests that require authentication."""
|
"""Base class for channel-level requests that require authentication."""
|
||||||
|
|
||||||
token: str
|
token: str
|
||||||
|
|
||||||
|
|
||||||
class CreateGameRequest(ChannelRequest):
|
class CreateGameRequest(ChannelRequest):
|
||||||
"""Create a new game."""
|
"""Create a new game."""
|
||||||
|
|
||||||
name: Literal["create_game"] = "create_game"
|
name: Literal["create_game"] = "create_game"
|
||||||
map_name: str = "standard"
|
map_name: str = "standard"
|
||||||
rules: List[str] = Field(default_factory=lambda: ["NO_PRESS", "IGNORE_ERRORS"])
|
rules: List[str] = Field(default_factory=lambda: ["NO_PRESS", "IGNORE_ERRORS"])
|
||||||
|
|
@ -78,6 +89,7 @@ class CreateGameRequest(ChannelRequest):
|
||||||
|
|
||||||
class JoinGameRequest(ChannelRequest):
|
class JoinGameRequest(ChannelRequest):
|
||||||
"""Join an existing game."""
|
"""Join an existing game."""
|
||||||
|
|
||||||
name: Literal["join_game"] = "join_game"
|
name: Literal["join_game"] = "join_game"
|
||||||
game_id: str
|
game_id: str
|
||||||
power_name: Optional[str] = None
|
power_name: Optional[str] = None
|
||||||
|
|
@ -86,6 +98,7 @@ class JoinGameRequest(ChannelRequest):
|
||||||
|
|
||||||
class JoinPowersRequest(ChannelRequest):
|
class JoinPowersRequest(ChannelRequest):
|
||||||
"""Join multiple powers in a game."""
|
"""Join multiple powers in a game."""
|
||||||
|
|
||||||
name: Literal["join_powers"] = "join_powers"
|
name: Literal["join_powers"] = "join_powers"
|
||||||
game_id: str
|
game_id: str
|
||||||
power_names: List[str]
|
power_names: List[str]
|
||||||
|
|
@ -94,6 +107,7 @@ class JoinPowersRequest(ChannelRequest):
|
||||||
|
|
||||||
class ListGamesRequest(ChannelRequest):
|
class ListGamesRequest(ChannelRequest):
|
||||||
"""List available games."""
|
"""List available games."""
|
||||||
|
|
||||||
name: Literal["list_games"] = "list_games"
|
name: Literal["list_games"] = "list_games"
|
||||||
game_id_filter: Optional[str] = None
|
game_id_filter: Optional[str] = None
|
||||||
map_name: Optional[str] = None
|
map_name: Optional[str] = None
|
||||||
|
|
@ -103,23 +117,27 @@ class ListGamesRequest(ChannelRequest):
|
||||||
|
|
||||||
class GetPlayablePowersRequest(ChannelRequest):
|
class GetPlayablePowersRequest(ChannelRequest):
|
||||||
"""Get uncontrolled powers in a game."""
|
"""Get uncontrolled powers in a game."""
|
||||||
|
|
||||||
name: Literal["get_playable_powers"] = "get_playable_powers"
|
name: Literal["get_playable_powers"] = "get_playable_powers"
|
||||||
game_id: str
|
game_id: str
|
||||||
|
|
||||||
|
|
||||||
class GetAvailableMapsRequest(ChannelRequest):
|
class GetAvailableMapsRequest(ChannelRequest):
|
||||||
"""Get list of available maps."""
|
"""Get list of available maps."""
|
||||||
|
|
||||||
name: Literal["get_available_maps"] = "get_available_maps"
|
name: Literal["get_available_maps"] = "get_available_maps"
|
||||||
|
|
||||||
|
|
||||||
class GetDummyWaitingPowersRequest(ChannelRequest):
|
class GetDummyWaitingPowersRequest(ChannelRequest):
|
||||||
"""Get AI-controllable powers (bot use)."""
|
"""Get AI-controllable powers (bot use)."""
|
||||||
|
|
||||||
name: Literal["get_dummy_waiting_powers"] = "get_dummy_waiting_powers"
|
name: Literal["get_dummy_waiting_powers"] = "get_dummy_waiting_powers"
|
||||||
game_id: str
|
game_id: str
|
||||||
|
|
||||||
|
|
||||||
class SetGradeRequest(ChannelRequest):
|
class SetGradeRequest(ChannelRequest):
|
||||||
"""Modify user permissions."""
|
"""Modify user permissions."""
|
||||||
|
|
||||||
name: Literal["set_grade"] = "set_grade"
|
name: Literal["set_grade"] = "set_grade"
|
||||||
username: str
|
username: str
|
||||||
grade: str
|
grade: str
|
||||||
|
|
@ -127,11 +145,13 @@ class SetGradeRequest(ChannelRequest):
|
||||||
|
|
||||||
class DeleteAccountRequest(ChannelRequest):
|
class DeleteAccountRequest(ChannelRequest):
|
||||||
"""Delete user account."""
|
"""Delete user account."""
|
||||||
|
|
||||||
name: Literal["delete_account"] = "delete_account"
|
name: Literal["delete_account"] = "delete_account"
|
||||||
|
|
||||||
|
|
||||||
class LogoutRequest(ChannelRequest):
|
class LogoutRequest(ChannelRequest):
|
||||||
"""Disconnect from server."""
|
"""Disconnect from server."""
|
||||||
|
|
||||||
name: Literal["logout"] = "logout"
|
name: Literal["logout"] = "logout"
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -139,8 +159,10 @@ class LogoutRequest(ChannelRequest):
|
||||||
# Game Level Messages (require authentication + game context)
|
# Game Level Messages (require authentication + game context)
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
class GameRequest(ChannelRequest):
|
class GameRequest(ChannelRequest):
|
||||||
"""Base class for game-level requests."""
|
"""Base class for game-level requests."""
|
||||||
|
|
||||||
game_id: str
|
game_id: str
|
||||||
game_role: str # Power name like "ENGLAND"
|
game_role: str # Power name like "ENGLAND"
|
||||||
phase: Optional[str] = None
|
phase: Optional[str] = None
|
||||||
|
|
@ -148,18 +170,21 @@ class GameRequest(ChannelRequest):
|
||||||
|
|
||||||
class SetOrdersRequest(GameRequest):
|
class SetOrdersRequest(GameRequest):
|
||||||
"""Submit orders for a power."""
|
"""Submit orders for a power."""
|
||||||
|
|
||||||
name: Literal["set_orders"] = "set_orders"
|
name: Literal["set_orders"] = "set_orders"
|
||||||
orders: List[str]
|
orders: List[str]
|
||||||
|
|
||||||
|
|
||||||
class SetWaitFlagRequest(GameRequest):
|
class SetWaitFlagRequest(GameRequest):
|
||||||
"""Set wait flag for turn processing."""
|
"""Set wait flag for turn processing."""
|
||||||
|
|
||||||
name: Literal["set_wait_flag"] = "set_wait_flag"
|
name: Literal["set_wait_flag"] = "set_wait_flag"
|
||||||
wait: bool
|
wait: bool
|
||||||
|
|
||||||
|
|
||||||
class SendGameMessageRequest(GameRequest):
|
class SendGameMessageRequest(GameRequest):
|
||||||
"""Send diplomatic message."""
|
"""Send diplomatic message."""
|
||||||
|
|
||||||
name: Literal["send_game_message"] = "send_game_message"
|
name: Literal["send_game_message"] = "send_game_message"
|
||||||
recipient: str # Power name or "GLOBAL"
|
recipient: str # Power name or "GLOBAL"
|
||||||
message: str
|
message: str
|
||||||
|
|
@ -168,11 +193,13 @@ class SendGameMessageRequest(GameRequest):
|
||||||
|
|
||||||
class GetAllPossibleOrdersRequest(GameRequest):
|
class GetAllPossibleOrdersRequest(GameRequest):
|
||||||
"""Get legal orders for current phase."""
|
"""Get legal orders for current phase."""
|
||||||
|
|
||||||
name: Literal["get_all_possible_orders"] = "get_all_possible_orders"
|
name: Literal["get_all_possible_orders"] = "get_all_possible_orders"
|
||||||
|
|
||||||
|
|
||||||
class GetPhaseHistoryRequest(GameRequest):
|
class GetPhaseHistoryRequest(GameRequest):
|
||||||
"""Get historical game phases."""
|
"""Get historical game phases."""
|
||||||
|
|
||||||
name: Literal["get_phase_history"] = "get_phase_history"
|
name: Literal["get_phase_history"] = "get_phase_history"
|
||||||
from_phase: Optional[str] = None
|
from_phase: Optional[str] = None
|
||||||
to_phase: Optional[str] = None
|
to_phase: Optional[str] = None
|
||||||
|
|
@ -180,45 +207,53 @@ class GetPhaseHistoryRequest(GameRequest):
|
||||||
|
|
||||||
class ProcessGameRequest(GameRequest):
|
class ProcessGameRequest(GameRequest):
|
||||||
"""Force game processing (master only)."""
|
"""Force game processing (master only)."""
|
||||||
|
|
||||||
name: Literal["process_game"] = "process_game"
|
name: Literal["process_game"] = "process_game"
|
||||||
|
|
||||||
|
|
||||||
class VoteRequest(GameRequest):
|
class VoteRequest(GameRequest):
|
||||||
"""Vote for/against draw."""
|
"""Vote for/against draw."""
|
||||||
|
|
||||||
name: Literal["vote"] = "vote"
|
name: Literal["vote"] = "vote"
|
||||||
vote: Literal["yes", "no"]
|
vote: Literal["yes", "no"]
|
||||||
|
|
||||||
|
|
||||||
class SaveGameRequest(GameRequest):
|
class SaveGameRequest(GameRequest):
|
||||||
"""Export game as JSON."""
|
"""Export game as JSON."""
|
||||||
|
|
||||||
name: Literal["save_game"] = "save_game"
|
name: Literal["save_game"] = "save_game"
|
||||||
|
|
||||||
|
|
||||||
class SetGameStateRequest(GameRequest):
|
class SetGameStateRequest(GameRequest):
|
||||||
"""Modify game state (master only)."""
|
"""Modify game state (master only)."""
|
||||||
|
|
||||||
name: Literal["set_game_state"] = "set_game_state"
|
name: Literal["set_game_state"] = "set_game_state"
|
||||||
state: Dict[str, Any]
|
state: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
class SetGameStatusRequest(GameRequest):
|
class SetGameStatusRequest(GameRequest):
|
||||||
"""Change game status (master only)."""
|
"""Change game status (master only)."""
|
||||||
|
|
||||||
name: Literal["set_game_status"] = "set_game_status"
|
name: Literal["set_game_status"] = "set_game_status"
|
||||||
status: str
|
status: str
|
||||||
|
|
||||||
|
|
||||||
class SetDummyPowersRequest(GameRequest):
|
class SetDummyPowersRequest(GameRequest):
|
||||||
"""Make powers AI-controlled (master only)."""
|
"""Make powers AI-controlled (master only)."""
|
||||||
|
|
||||||
name: Literal["set_dummy_powers"] = "set_dummy_powers"
|
name: Literal["set_dummy_powers"] = "set_dummy_powers"
|
||||||
power_names: List[str]
|
power_names: List[str]
|
||||||
|
|
||||||
|
|
||||||
class DeleteGameRequest(GameRequest):
|
class DeleteGameRequest(GameRequest):
|
||||||
"""Delete game (master only)."""
|
"""Delete game (master only)."""
|
||||||
|
|
||||||
name: Literal["delete_game"] = "delete_game"
|
name: Literal["delete_game"] = "delete_game"
|
||||||
|
|
||||||
|
|
||||||
class LeaveGameRequest(GameRequest):
|
class LeaveGameRequest(GameRequest):
|
||||||
"""Leave game."""
|
"""Leave game."""
|
||||||
|
|
||||||
name: Literal["leave_game"] = "leave_game"
|
name: Literal["leave_game"] = "leave_game"
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -226,13 +261,16 @@ class LeaveGameRequest(GameRequest):
|
||||||
# Response Messages (Server -> Client)
|
# Response Messages (Server -> Client)
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
class OkResponse(BaseResponse):
|
class OkResponse(BaseResponse):
|
||||||
"""Generic success response."""
|
"""Generic success response."""
|
||||||
|
|
||||||
name: Literal["ok"] = "ok"
|
name: Literal["ok"] = "ok"
|
||||||
|
|
||||||
|
|
||||||
class ErrorResponse(BaseResponse):
|
class ErrorResponse(BaseResponse):
|
||||||
"""Error response with error type and message."""
|
"""Error response with error type and message."""
|
||||||
|
|
||||||
name: Literal["error"] = "error"
|
name: Literal["error"] = "error"
|
||||||
error_type: str
|
error_type: str
|
||||||
message: str
|
message: str
|
||||||
|
|
@ -240,60 +278,70 @@ class ErrorResponse(BaseResponse):
|
||||||
|
|
||||||
class DataTokenResponse(BaseResponse):
|
class DataTokenResponse(BaseResponse):
|
||||||
"""Contains authentication token."""
|
"""Contains authentication token."""
|
||||||
|
|
||||||
name: Literal["data_token"] = "data_token"
|
name: Literal["data_token"] = "data_token"
|
||||||
data: str # The authentication token
|
data: str # The authentication token
|
||||||
|
|
||||||
|
|
||||||
class DataGameResponse(BaseResponse):
|
class DataGameResponse(BaseResponse):
|
||||||
"""Contains full game object."""
|
"""Contains full game object."""
|
||||||
|
|
||||||
name: Literal["data_game"] = "data_game"
|
name: Literal["data_game"] = "data_game"
|
||||||
data: Dict[str, Any] # The complete game state
|
data: Dict[str, Any] # The complete game state
|
||||||
|
|
||||||
|
|
||||||
class DataGameInfoResponse(BaseResponse):
|
class DataGameInfoResponse(BaseResponse):
|
||||||
"""Contains game metadata."""
|
"""Contains game metadata."""
|
||||||
|
|
||||||
name: Literal["data_game_info"] = "data_game_info"
|
name: Literal["data_game_info"] = "data_game_info"
|
||||||
data: Dict[str, Any]
|
data: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
class DataGamesResponse(BaseResponse):
|
class DataGamesResponse(BaseResponse):
|
||||||
"""List of game information."""
|
"""List of game information."""
|
||||||
|
|
||||||
name: Literal["data_games"] = "data_games"
|
name: Literal["data_games"] = "data_games"
|
||||||
data: List[Dict[str, Any]]
|
data: List[Dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
class DataMapsResponse(BaseResponse):
|
class DataMapsResponse(BaseResponse):
|
||||||
"""Available maps information."""
|
"""Available maps information."""
|
||||||
|
|
||||||
name: Literal["data_maps"] = "data_maps"
|
name: Literal["data_maps"] = "data_maps"
|
||||||
data: List[str]
|
data: List[str]
|
||||||
|
|
||||||
|
|
||||||
class DataPowerNamesResponse(BaseResponse):
|
class DataPowerNamesResponse(BaseResponse):
|
||||||
"""List of power names."""
|
"""List of power names."""
|
||||||
|
|
||||||
name: Literal["data_power_names"] = "data_power_names"
|
name: Literal["data_power_names"] = "data_power_names"
|
||||||
data: List[str]
|
data: List[str]
|
||||||
|
|
||||||
|
|
||||||
class DataPossibleOrdersResponse(BaseResponse):
|
class DataPossibleOrdersResponse(BaseResponse):
|
||||||
"""Legal orders and locations."""
|
"""Legal orders and locations."""
|
||||||
|
|
||||||
name: Literal["data_possible_orders"] = "data_possible_orders"
|
name: Literal["data_possible_orders"] = "data_possible_orders"
|
||||||
data: Dict[str, List[str]] # Location -> list of possible orders
|
data: Dict[str, List[str]] # Location -> list of possible orders
|
||||||
|
|
||||||
|
|
||||||
class DataGamePhasesResponse(BaseResponse):
|
class DataGamePhasesResponse(BaseResponse):
|
||||||
"""Historical game phases."""
|
"""Historical game phases."""
|
||||||
|
|
||||||
name: Literal["data_game_phases"] = "data_game_phases"
|
name: Literal["data_game_phases"] = "data_game_phases"
|
||||||
data: List[Dict[str, Any]]
|
data: List[Dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
class DataSavedGameResponse(BaseResponse):
|
class DataSavedGameResponse(BaseResponse):
|
||||||
"""Exported game JSON."""
|
"""Exported game JSON."""
|
||||||
|
|
||||||
name: Literal["data_saved_game"] = "data_saved_game"
|
name: Literal["data_saved_game"] = "data_saved_game"
|
||||||
data: Dict[str, Any]
|
data: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
class DataPortResponse(BaseResponse):
|
class DataPortResponse(BaseResponse):
|
||||||
"""DAIDE port number."""
|
"""DAIDE port number."""
|
||||||
|
|
||||||
name: Literal["data_port"] = "data_port"
|
name: Literal["data_port"] = "data_port"
|
||||||
data: int
|
data: int
|
||||||
|
|
||||||
|
|
@ -302,8 +350,10 @@ class DataPortResponse(BaseResponse):
|
||||||
# Notification Messages (Server -> Client)
|
# Notification Messages (Server -> Client)
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
class GameProcessedNotification(BaseNotification):
|
class GameProcessedNotification(BaseNotification):
|
||||||
"""Phase completed, new orders phase."""
|
"""Phase completed, new orders phase."""
|
||||||
|
|
||||||
name: Literal["game_processed"] = "game_processed"
|
name: Literal["game_processed"] = "game_processed"
|
||||||
game_id: str
|
game_id: str
|
||||||
phase: str
|
phase: str
|
||||||
|
|
@ -312,6 +362,7 @@ class GameProcessedNotification(BaseNotification):
|
||||||
|
|
||||||
class GamePhaseUpdateNotification(BaseNotification):
|
class GamePhaseUpdateNotification(BaseNotification):
|
||||||
"""Game state changed."""
|
"""Game state changed."""
|
||||||
|
|
||||||
name: Literal["game_phase_update"] = "game_phase_update"
|
name: Literal["game_phase_update"] = "game_phase_update"
|
||||||
game_id: str
|
game_id: str
|
||||||
phase: str
|
phase: str
|
||||||
|
|
@ -320,6 +371,7 @@ class GamePhaseUpdateNotification(BaseNotification):
|
||||||
|
|
||||||
class GameStatusUpdateNotification(BaseNotification):
|
class GameStatusUpdateNotification(BaseNotification):
|
||||||
"""Game status changed (forming/active/paused/completed)."""
|
"""Game status changed (forming/active/paused/completed)."""
|
||||||
|
|
||||||
name: Literal["game_status_update"] = "game_status_update"
|
name: Literal["game_status_update"] = "game_status_update"
|
||||||
game_id: str
|
game_id: str
|
||||||
status: str
|
status: str
|
||||||
|
|
@ -327,6 +379,7 @@ class GameStatusUpdateNotification(BaseNotification):
|
||||||
|
|
||||||
class PowersControllersNotification(BaseNotification):
|
class PowersControllersNotification(BaseNotification):
|
||||||
"""Power control assignments changed."""
|
"""Power control assignments changed."""
|
||||||
|
|
||||||
name: Literal["powers_controllers"] = "powers_controllers"
|
name: Literal["powers_controllers"] = "powers_controllers"
|
||||||
game_id: str
|
game_id: str
|
||||||
controllers: Dict[str, str] # Power -> Controller mapping
|
controllers: Dict[str, str] # Power -> Controller mapping
|
||||||
|
|
@ -334,6 +387,7 @@ class PowersControllersNotification(BaseNotification):
|
||||||
|
|
||||||
class PowerOrdersUpdateNotification(BaseNotification):
|
class PowerOrdersUpdateNotification(BaseNotification):
|
||||||
"""Player submitted new orders."""
|
"""Player submitted new orders."""
|
||||||
|
|
||||||
name: Literal["power_orders_update"] = "power_orders_update"
|
name: Literal["power_orders_update"] = "power_orders_update"
|
||||||
game_id: str
|
game_id: str
|
||||||
power_name: str
|
power_name: str
|
||||||
|
|
@ -343,6 +397,7 @@ class PowerOrdersUpdateNotification(BaseNotification):
|
||||||
|
|
||||||
class PowerOrdersFlagNotification(BaseNotification):
|
class PowerOrdersFlagNotification(BaseNotification):
|
||||||
"""Player order submission status."""
|
"""Player order submission status."""
|
||||||
|
|
||||||
name: Literal["power_orders_flag"] = "power_orders_flag"
|
name: Literal["power_orders_flag"] = "power_orders_flag"
|
||||||
game_id: str
|
game_id: str
|
||||||
power_name: str
|
power_name: str
|
||||||
|
|
@ -352,6 +407,7 @@ class PowerOrdersFlagNotification(BaseNotification):
|
||||||
|
|
||||||
class PowerWaitFlagNotification(BaseNotification):
|
class PowerWaitFlagNotification(BaseNotification):
|
||||||
"""Player wait flag changed."""
|
"""Player wait flag changed."""
|
||||||
|
|
||||||
name: Literal["power_wait_flag"] = "power_wait_flag"
|
name: Literal["power_wait_flag"] = "power_wait_flag"
|
||||||
game_id: str
|
game_id: str
|
||||||
power_name: str
|
power_name: str
|
||||||
|
|
@ -360,6 +416,7 @@ class PowerWaitFlagNotification(BaseNotification):
|
||||||
|
|
||||||
class GameMessageReceivedNotification(BaseNotification):
|
class GameMessageReceivedNotification(BaseNotification):
|
||||||
"""Diplomatic message received."""
|
"""Diplomatic message received."""
|
||||||
|
|
||||||
name: Literal["game_message_received"] = "game_message_received"
|
name: Literal["game_message_received"] = "game_message_received"
|
||||||
game_id: str
|
game_id: str
|
||||||
sender: str
|
sender: str
|
||||||
|
|
@ -371,6 +428,7 @@ class GameMessageReceivedNotification(BaseNotification):
|
||||||
|
|
||||||
class VoteUpdatedNotification(BaseNotification):
|
class VoteUpdatedNotification(BaseNotification):
|
||||||
"""Draw votes changed (omniscient view)."""
|
"""Draw votes changed (omniscient view)."""
|
||||||
|
|
||||||
name: Literal["vote_updated"] = "vote_updated"
|
name: Literal["vote_updated"] = "vote_updated"
|
||||||
game_id: str
|
game_id: str
|
||||||
votes: Dict[str, str] # Power -> vote mapping
|
votes: Dict[str, str] # Power -> vote mapping
|
||||||
|
|
@ -378,6 +436,7 @@ class VoteUpdatedNotification(BaseNotification):
|
||||||
|
|
||||||
class VoteCountUpdatedNotification(BaseNotification):
|
class VoteCountUpdatedNotification(BaseNotification):
|
||||||
"""Vote count changed (observer view)."""
|
"""Vote count changed (observer view)."""
|
||||||
|
|
||||||
name: Literal["vote_count_updated"] = "vote_count_updated"
|
name: Literal["vote_count_updated"] = "vote_count_updated"
|
||||||
game_id: str
|
game_id: str
|
||||||
count_yes: int
|
count_yes: int
|
||||||
|
|
@ -386,6 +445,7 @@ class VoteCountUpdatedNotification(BaseNotification):
|
||||||
|
|
||||||
class PowerVoteUpdatedNotification(BaseNotification):
|
class PowerVoteUpdatedNotification(BaseNotification):
|
||||||
"""Own power's vote changed."""
|
"""Own power's vote changed."""
|
||||||
|
|
||||||
name: Literal["power_vote_updated"] = "power_vote_updated"
|
name: Literal["power_vote_updated"] = "power_vote_updated"
|
||||||
game_id: str
|
game_id: str
|
||||||
power_name: str
|
power_name: str
|
||||||
|
|
@ -394,12 +454,14 @@ class PowerVoteUpdatedNotification(BaseNotification):
|
||||||
|
|
||||||
class GameDeletedNotification(BaseNotification):
|
class GameDeletedNotification(BaseNotification):
|
||||||
"""Game removed from server."""
|
"""Game removed from server."""
|
||||||
|
|
||||||
name: Literal["game_deleted"] = "game_deleted"
|
name: Literal["game_deleted"] = "game_deleted"
|
||||||
game_id: str
|
game_id: str
|
||||||
|
|
||||||
|
|
||||||
class OmniscientUpdatedNotification(BaseNotification):
|
class OmniscientUpdatedNotification(BaseNotification):
|
||||||
"""Observer permissions changed."""
|
"""Observer permissions changed."""
|
||||||
|
|
||||||
name: Literal["omniscient_updated"] = "omniscient_updated"
|
name: Literal["omniscient_updated"] = "omniscient_updated"
|
||||||
game_id: str
|
game_id: str
|
||||||
omniscient_type: str
|
omniscient_type: str
|
||||||
|
|
@ -407,12 +469,14 @@ class OmniscientUpdatedNotification(BaseNotification):
|
||||||
|
|
||||||
class AccountDeletedNotification(BaseNotification):
|
class AccountDeletedNotification(BaseNotification):
|
||||||
"""User account deleted."""
|
"""User account deleted."""
|
||||||
|
|
||||||
name: Literal["account_deleted"] = "account_deleted"
|
name: Literal["account_deleted"] = "account_deleted"
|
||||||
username: str
|
username: str
|
||||||
|
|
||||||
|
|
||||||
class ClearedCentersNotification(BaseNotification):
|
class ClearedCentersNotification(BaseNotification):
|
||||||
"""Supply centers cleared."""
|
"""Supply centers cleared."""
|
||||||
|
|
||||||
name: Literal["cleared_centers"] = "cleared_centers"
|
name: Literal["cleared_centers"] = "cleared_centers"
|
||||||
game_id: str
|
game_id: str
|
||||||
power_name: str
|
power_name: str
|
||||||
|
|
@ -420,6 +484,7 @@ class ClearedCentersNotification(BaseNotification):
|
||||||
|
|
||||||
class ClearedOrdersNotification(BaseNotification):
|
class ClearedOrdersNotification(BaseNotification):
|
||||||
"""Orders cleared."""
|
"""Orders cleared."""
|
||||||
|
|
||||||
name: Literal["cleared_orders"] = "cleared_orders"
|
name: Literal["cleared_orders"] = "cleared_orders"
|
||||||
game_id: str
|
game_id: str
|
||||||
power_name: str
|
power_name: str
|
||||||
|
|
@ -428,6 +493,7 @@ class ClearedOrdersNotification(BaseNotification):
|
||||||
|
|
||||||
class ClearedUnitsNotification(BaseNotification):
|
class ClearedUnitsNotification(BaseNotification):
|
||||||
"""Units cleared."""
|
"""Units cleared."""
|
||||||
|
|
||||||
name: Literal["cleared_units"] = "cleared_units"
|
name: Literal["cleared_units"] = "cleared_units"
|
||||||
game_id: str
|
game_id: str
|
||||||
power_name: str
|
power_name: str
|
||||||
|
|
@ -514,6 +580,7 @@ WebSocketMessage = Union[RequestMessage, ResponseMessage, NotificationMessage]
|
||||||
# Utility Functions
|
# Utility Functions
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
def parse_message(data: Dict[str, Any]) -> WebSocketMessage:
|
def parse_message(data: Dict[str, Any]) -> WebSocketMessage:
|
||||||
"""
|
"""
|
||||||
Parse a raw WebSocket message dictionary into the appropriate pydantic model.
|
Parse a raw WebSocket message dictionary into the appropriate pydantic model.
|
||||||
|
|
|
||||||
|
|
@ -50,13 +50,13 @@ class MultiBotLauncher:
|
||||||
|
|
||||||
# Default power to model mapping
|
# Default power to model mapping
|
||||||
self.default_models = {
|
self.default_models = {
|
||||||
"AUSTRIA": "gpt-3.5-turbo",
|
"AUSTRIA": "gemini-2.5-flash-lite-preview-06-17",
|
||||||
"ENGLAND": "gpt-4",
|
"ENGLAND": "gemini-2.5-flash-lite-preview-06-17",
|
||||||
"FRANCE": "claude-3-haiku",
|
"FRANCE": "gemini-2.5-flash-lite-preview-06-17",
|
||||||
"GERMANY": "gpt-3.5-turbo",
|
"GERMANY": "gemini-2.5-flash-lite-preview-06-17",
|
||||||
"ITALY": "gemini-pro",
|
"ITALY": "gemini-2.5-flash-lite-preview-06-17",
|
||||||
"RUSSIA": "gpt-4",
|
"RUSSIA": "gemini-2.5-flash-lite-preview-06-17",
|
||||||
"TURKEY": "claude-3-sonnet",
|
"TURKEY": "gemini-2.5-flash-lite-preview-06-17",
|
||||||
}
|
}
|
||||||
|
|
||||||
async def create_game(self, creator_power: str = "FRANCE") -> str:
|
async def create_game(self, creator_power: str = "FRANCE") -> str:
|
||||||
|
|
@ -99,7 +99,12 @@ class MultiBotLauncher:
|
||||||
return game_id
|
return game_id
|
||||||
|
|
||||||
def launch_bot(
|
def launch_bot(
|
||||||
self, power: str, model: str, game_id: str, log_level: str = "INFO"
|
self,
|
||||||
|
power: str,
|
||||||
|
model: str,
|
||||||
|
game_id: str,
|
||||||
|
log_level: str = "INFO",
|
||||||
|
negotiation_rounds: int = 3,
|
||||||
) -> subprocess.Popen:
|
) -> subprocess.Popen:
|
||||||
"""
|
"""
|
||||||
Launch a single bot process.
|
Launch a single bot process.
|
||||||
|
|
@ -134,6 +139,8 @@ class MultiBotLauncher:
|
||||||
game_id,
|
game_id,
|
||||||
"--log-level",
|
"--log-level",
|
||||||
log_level,
|
log_level,
|
||||||
|
"--negotiation-rounds",
|
||||||
|
str(negotiation_rounds),
|
||||||
]
|
]
|
||||||
|
|
||||||
logger.info(f"Launching bot for {power} with model {model}")
|
logger.info(f"Launching bot for {power} with model {model}")
|
||||||
|
|
@ -156,7 +163,8 @@ class MultiBotLauncher:
|
||||||
models: Optional[Dict[str, str]] = None,
|
models: Optional[Dict[str, str]] = None,
|
||||||
powers: Optional[List[str]] = None,
|
powers: Optional[List[str]] = None,
|
||||||
log_level: str = "INFO",
|
log_level: str = "INFO",
|
||||||
stagger_delay: float = 2.0,
|
stagger_delay: float = 0.5,
|
||||||
|
negotiation_rounds: int = 3,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Launch bots for all specified powers.
|
Launch bots for all specified powers.
|
||||||
|
|
@ -180,13 +188,11 @@ class MultiBotLauncher:
|
||||||
model = models.get(power, "gpt-3.5-turbo")
|
model = models.get(power, "gpt-3.5-turbo")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
process = self.launch_bot(power, model, game_id, log_level)
|
process = self.launch_bot(power, model, game_id, log_level, negotiation_rounds)
|
||||||
self.bot_processes.append(process)
|
self.bot_processes.append(process)
|
||||||
self.process_to_power[process] = power
|
self.process_to_power[process] = power
|
||||||
|
|
||||||
logger.info(
|
logger.info(f"Launched bot {i + 1}/{len(powers)}: {power} (PID: {process.pid})")
|
||||||
f"Launched bot {i + 1}/{len(powers)}: {power} (PID: {process.pid})"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Stagger the launches to avoid overwhelming the server
|
# Stagger the launches to avoid overwhelming the server
|
||||||
if i < len(powers) - 1: # Don't delay after the last bot
|
if i < len(powers) - 1: # Don't delay after the last bot
|
||||||
|
|
@ -224,9 +230,7 @@ class MultiBotLauncher:
|
||||||
# Process has ended
|
# Process has ended
|
||||||
return_code = process.returncode
|
return_code = process.returncode
|
||||||
power = self.process_to_power.get(process, "UNKNOWN")
|
power = self.process_to_power.get(process, "UNKNOWN")
|
||||||
logger.info(
|
logger.info(f"{power} bot process {process.pid} ended with code {return_code}")
|
||||||
f"{power} bot process {process.pid} ended with code {return_code}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Read any remaining output
|
# Read any remaining output
|
||||||
remaining_output = process.stdout.read()
|
remaining_output = process.stdout.read()
|
||||||
|
|
@ -243,7 +247,7 @@ class MultiBotLauncher:
|
||||||
break
|
break
|
||||||
|
|
||||||
# Use select to check which processes have output ready (Unix only)
|
# Use select to check which processes have output ready (Unix only)
|
||||||
if stdout_fds and hasattr(select, 'select'):
|
if stdout_fds and hasattr(select, "select"):
|
||||||
try:
|
try:
|
||||||
ready_fds, _, _ = select.select(stdout_fds, [], [], 0.1) # 100ms timeout
|
ready_fds, _, _ = select.select(stdout_fds, [], [], 0.1) # 100ms timeout
|
||||||
|
|
||||||
|
|
@ -324,6 +328,7 @@ class MultiBotLauncher:
|
||||||
models: Optional[Dict[str, str]] = None,
|
models: Optional[Dict[str, str]] = None,
|
||||||
log_level: str = "INFO",
|
log_level: str = "INFO",
|
||||||
creator_power: str = "FRANCE",
|
creator_power: str = "FRANCE",
|
||||||
|
negotiation_rounds: int = 3,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create a game and launch all bots for a complete game.
|
Create a game and launch all bots for a complete game.
|
||||||
|
|
@ -342,7 +347,12 @@ class MultiBotLauncher:
|
||||||
await asyncio.sleep(2)
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
# Launch all bots
|
# Launch all bots
|
||||||
await self.launch_all_bots(game_id, models, log_level=log_level)
|
await self.launch_all_bots(
|
||||||
|
game_id,
|
||||||
|
models,
|
||||||
|
log_level=log_level,
|
||||||
|
negotiation_rounds=negotiation_rounds,
|
||||||
|
)
|
||||||
|
|
||||||
# Monitor the bots
|
# Monitor the bots
|
||||||
self.monitor_bots()
|
self.monitor_bots()
|
||||||
|
|
@ -358,6 +368,7 @@ class MultiBotLauncher:
|
||||||
powers: List[str],
|
powers: List[str],
|
||||||
models: Optional[Dict[str, str]] = None,
|
models: Optional[Dict[str, str]] = None,
|
||||||
log_level: str = "INFO",
|
log_level: str = "INFO",
|
||||||
|
negotiation_rounds: int = 3,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Launch bots to join an existing game.
|
Launch bots to join an existing game.
|
||||||
|
|
@ -372,7 +383,13 @@ class MultiBotLauncher:
|
||||||
self.game_id = game_id
|
self.game_id = game_id
|
||||||
|
|
||||||
# Launch bots for specified powers
|
# Launch bots for specified powers
|
||||||
await self.launch_all_bots(game_id, models, powers, log_level)
|
await self.launch_all_bots(
|
||||||
|
game_id,
|
||||||
|
models,
|
||||||
|
powers,
|
||||||
|
log_level,
|
||||||
|
negotiation_rounds=negotiation_rounds,
|
||||||
|
)
|
||||||
|
|
||||||
# Monitor the bots
|
# Monitor the bots
|
||||||
self.monitor_bots()
|
self.monitor_bots()
|
||||||
|
|
@ -391,18 +408,16 @@ def parse_arguments():
|
||||||
parser.add_argument("--port", type=int, default=8432, help="Server port")
|
parser.add_argument("--port", type=int, default=8432, help="Server port")
|
||||||
parser.add_argument("--username-base", default="bot", help="Base username for bots")
|
parser.add_argument("--username-base", default="bot", help="Base username for bots")
|
||||||
parser.add_argument("--password", default="password", help="Password for all bots")
|
parser.add_argument("--password", default="password", help="Password for all bots")
|
||||||
parser.add_argument(
|
parser.add_argument("--game-id", help="Game ID to join (creates new if not specified)")
|
||||||
"--game-id", help="Game ID to join (creates new if not specified)"
|
parser.add_argument("--powers", nargs="+", help="Powers to launch bots for (default: all)")
|
||||||
)
|
parser.add_argument("--models", help="Comma-separated list of models in power order")
|
||||||
parser.add_argument(
|
|
||||||
"--powers", nargs="+", help="Powers to launch bots for (default: all)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--models", help="Comma-separated list of models in power order"
|
|
||||||
)
|
|
||||||
parser.add_argument("--log-level", default="INFO", help="Logging level")
|
parser.add_argument("--log-level", default="INFO", help="Logging level")
|
||||||
|
parser.add_argument("--creator-power", default="FRANCE", help="Power that creates the game")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--creator-power", default="FRANCE", help="Power that creates the game"
|
"--negotiation-rounds",
|
||||||
|
type=int,
|
||||||
|
default=3,
|
||||||
|
help="Number of negotiation rounds per movement phase (default: 3)",
|
||||||
)
|
)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
@ -425,9 +440,7 @@ async def main():
|
||||||
model_list = [m.strip() for m in args.models.split(",")]
|
model_list = [m.strip() for m in args.models.split(",")]
|
||||||
powers = args.powers or list(launcher.default_models.keys())
|
powers = args.powers or list(launcher.default_models.keys())
|
||||||
if len(model_list) != len(powers):
|
if len(model_list) != len(powers):
|
||||||
logger.error(
|
logger.error(f"Number of models ({len(model_list)}) must match number of powers ({len(powers)})")
|
||||||
f"Number of models ({len(model_list)}) must match number of powers ({len(powers)})"
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
models = dict(zip(powers, model_list))
|
models = dict(zip(powers, model_list))
|
||||||
|
|
||||||
|
|
@ -440,6 +453,7 @@ async def main():
|
||||||
powers=powers,
|
powers=powers,
|
||||||
models=models,
|
models=models,
|
||||||
log_level=args.log_level,
|
log_level=args.log_level,
|
||||||
|
negotiation_rounds=args.negotiation_rounds,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Create new game and launch all bots
|
# Create new game and launch all bots
|
||||||
|
|
@ -447,6 +461,7 @@ async def main():
|
||||||
models=models,
|
models=models,
|
||||||
log_level=args.log_level,
|
log_level=args.log_level,
|
||||||
creator_power=args.creator_power,
|
creator_power=args.creator_power,
|
||||||
|
negotiation_rounds=args.negotiation_rounds,
|
||||||
)
|
)
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@ include = [".", ".."]
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
src = ["..", "."]
|
src = ["..", "."]
|
||||||
|
line-length = 150
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
testpaths = ["bot_client/tests"]
|
testpaths = ["bot_client/tests"]
|
||||||
|
|
@ -44,8 +45,4 @@ markers = [
|
||||||
]
|
]
|
||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
dev = [
|
dev = ["pytest>=8.4.1", "pytest-asyncio>=0.21.0", "websockets>=15.0.1"]
|
||||||
"pytest>=8.4.1",
|
|
||||||
"pytest-asyncio>=0.21.0",
|
|
||||||
"websockets>=15.0.1",
|
|
||||||
]
|
|
||||||
|
|
|
||||||
|
|
@ -13,25 +13,17 @@ sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
import signal
|
import signal
|
||||||
from typing import Optional, Dict
|
from typing import Optional, Dict, List
|
||||||
import dotenv
|
import dotenv
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
from websocket_diplomacy_client import (
|
from websocket_diplomacy_client import WebSocketDiplomacyClient, connect_to_diplomacy_server
|
||||||
WebSocketDiplomacyClient,
|
|
||||||
connect_to_diplomacy_server,
|
|
||||||
)
|
|
||||||
|
|
||||||
from diplomacy.utils.exceptions import DiplomacyException, GameIdException
|
|
||||||
|
|
||||||
# Suppress warnings
|
|
||||||
# os.environ["GRPC_PYTHON_LOG_LEVEL"] = "40"
|
|
||||||
# os.environ["GRPC_VERBOSITY"] = "ERROR"
|
|
||||||
# os.environ["ABSL_MIN_LOG_LEVEL"] = "2"
|
|
||||||
# os.environ["GRPC_POLL_STRATEGY"] = "poll"
|
|
||||||
|
|
||||||
|
|
||||||
|
from diplomacy.engine.power import Power
|
||||||
|
from diplomacy.utils.exceptions import GameIdException
|
||||||
|
from diplomacy.communication.notifications import GameStatusUpdate
|
||||||
from diplomacy.engine.message import Message
|
from diplomacy.engine.message import Message
|
||||||
|
|
||||||
from ai_diplomacy.clients import load_model_client
|
from ai_diplomacy.clients import load_model_client
|
||||||
|
|
@ -40,15 +32,16 @@ from ai_diplomacy.game_history import GameHistory
|
||||||
from ai_diplomacy.agent import DiplomacyAgent
|
from ai_diplomacy.agent import DiplomacyAgent
|
||||||
from ai_diplomacy.initialization import initialize_agent_state_ext
|
from ai_diplomacy.initialization import initialize_agent_state_ext
|
||||||
from config import Configuration
|
from config import Configuration
|
||||||
|
from websocket_negotiations import (
|
||||||
|
conduct_strategic_negotiation_round,
|
||||||
|
should_participate_in_negotiations,
|
||||||
|
get_negotiation_delay,
|
||||||
|
)
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
config = Configuration()
|
# TODO: This, but better
|
||||||
|
config: Configuration = None
|
||||||
if config.DEBUG:
|
|
||||||
import tracemalloc
|
|
||||||
|
|
||||||
tracemalloc.start()
|
|
||||||
|
|
||||||
|
|
||||||
class SingleBotPlayer:
|
class SingleBotPlayer:
|
||||||
|
|
@ -70,6 +63,7 @@ class SingleBotPlayer:
|
||||||
hostname: str = "localhost",
|
hostname: str = "localhost",
|
||||||
port: int = 8432,
|
port: int = 8432,
|
||||||
game_id: Optional[str] = None,
|
game_id: Optional[str] = None,
|
||||||
|
negotiation_rounds: int = 3,
|
||||||
):
|
):
|
||||||
assert username is not None
|
assert username is not None
|
||||||
assert password is not None
|
assert password is not None
|
||||||
|
|
@ -87,16 +81,24 @@ class SingleBotPlayer:
|
||||||
# Bot state
|
# Bot state
|
||||||
self.client: WebSocketDiplomacyClient
|
self.client: WebSocketDiplomacyClient
|
||||||
self.agent: DiplomacyAgent
|
self.agent: DiplomacyAgent
|
||||||
self.game_history = GameHistory()
|
self.game_history: GameHistory = GameHistory()
|
||||||
self.running = True
|
self.running = True
|
||||||
self.current_phase = None
|
self.current_phase = None
|
||||||
self.waiting_for_orders = False
|
self.waiting_for_orders = False
|
||||||
self.orders_submitted = False
|
self.orders_submitted = False
|
||||||
|
|
||||||
|
# Negotiation settings
|
||||||
|
self.negotiation_rounds = negotiation_rounds
|
||||||
|
self.current_negotiation_round = 0
|
||||||
|
self.negotiation_complete = False
|
||||||
|
|
||||||
# Track error stats
|
# Track error stats
|
||||||
self.error_stats: Dict[str, Dict[str, int]] = {
|
self.error_stats: Dict[str, Dict[str, int]] = {self.model_name: {"conversation_errors": 0, "order_decoding_errors": 0}}
|
||||||
self.model_name: {"conversation_errors": 0, "order_decoding_errors": 0}
|
|
||||||
}
|
# Track messaging patterns for strategic communication
|
||||||
|
self.message_counts: Dict[str, int] = {} # Messages received from each power
|
||||||
|
self.response_counts: Dict[str, int] = {} # Responses sent to each power
|
||||||
|
self.priority_contacts: List[str] = [] # Powers to prioritize for communication
|
||||||
|
|
||||||
# Setup signal handlers for graceful shutdown
|
# Setup signal handlers for graceful shutdown
|
||||||
signal.signal(signal.SIGINT, self._signal_handler)
|
signal.signal(signal.SIGINT, self._signal_handler)
|
||||||
|
|
@ -122,9 +124,7 @@ class SingleBotPlayer:
|
||||||
# Join or create game
|
# Join or create game
|
||||||
if self.game_id:
|
if self.game_id:
|
||||||
logger.info(f"Joining existing game {self.game_id} as {self.power_name}")
|
logger.info(f"Joining existing game {self.game_id} as {self.power_name}")
|
||||||
game = await self.client.join_game(
|
await self.client.join_game(game_id=self.game_id, power_name=self.power_name)
|
||||||
game_id=self.game_id, power_name=self.power_name
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
logger.info(f"Creating new game as {self.power_name}")
|
logger.info(f"Creating new game as {self.power_name}")
|
||||||
await self.client.create_game(
|
await self.client.create_game(
|
||||||
|
|
@ -142,9 +142,7 @@ class SingleBotPlayer:
|
||||||
self.agent = DiplomacyAgent(power_name=self.power_name, client=model_client)
|
self.agent = DiplomacyAgent(power_name=self.power_name, client=model_client)
|
||||||
|
|
||||||
# Initialize agent state
|
# Initialize agent state
|
||||||
await initialize_agent_state_ext(
|
await initialize_agent_state_ext(self.agent, self.client.game, self.game_history, config.log_file_path)
|
||||||
self.agent, self.client.game, self.game_history, config.log_file_path
|
|
||||||
)
|
|
||||||
|
|
||||||
# Setup game event callbacks
|
# Setup game event callbacks
|
||||||
await self._setup_event_callbacks()
|
await self._setup_event_callbacks()
|
||||||
|
|
@ -198,6 +196,16 @@ class SingleBotPlayer:
|
||||||
self.current_phase = new_phase
|
self.current_phase = new_phase
|
||||||
self.game_history.add_phase(new_phase)
|
self.game_history.add_phase(new_phase)
|
||||||
self.orders_submitted = False
|
self.orders_submitted = False
|
||||||
|
self.current_negotiation_round = 0
|
||||||
|
self.negotiation_complete = False
|
||||||
|
|
||||||
|
# Log message statistics at phase transitions
|
||||||
|
if hasattr(self, "message_counts") and self.message_counts:
|
||||||
|
self.log_message_statistics()
|
||||||
|
|
||||||
|
# Check if we should start negotiations for movement phases (not adjustment phases)
|
||||||
|
if new_phase.endswith("M"):
|
||||||
|
await self._handle_negotiation_phase()
|
||||||
|
|
||||||
# Check if we need to submit orders for this new phase
|
# Check if we need to submit orders for this new phase
|
||||||
await self._check_if_orders_needed()
|
await self._check_if_orders_needed()
|
||||||
|
|
@ -223,9 +231,7 @@ class SingleBotPlayer:
|
||||||
def _on_message_received(self, game, notification):
|
def _on_message_received(self, game, notification):
|
||||||
"""Handle incoming diplomatic messages."""
|
"""Handle incoming diplomatic messages."""
|
||||||
message = notification.message
|
message = notification.message
|
||||||
logger.info(
|
logger.info(f"Message received from {message.sender} to {message.recipient}: {message.message}")
|
||||||
f"Message received from {message.sender} to {message.recipient}: {message.message}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add message to game history
|
# Add message to game history
|
||||||
self.game_history.add_message(
|
self.game_history.add_message(
|
||||||
|
|
@ -235,12 +241,17 @@ class SingleBotPlayer:
|
||||||
message_content=message.message,
|
message_content=message.message,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Track message patterns
|
||||||
|
if message.recipient == self.power_name and message.sender != self.power_name:
|
||||||
|
self.message_counts[message.sender] = self.message_counts.get(message.sender, 0) + 1
|
||||||
|
self._update_priority_contacts()
|
||||||
|
|
||||||
# If it's a private message to us, consider responding
|
# If it's a private message to us, consider responding
|
||||||
if message.recipient == self.power_name and message.sender != self.power_name:
|
if message.recipient == self.power_name and message.sender != self.power_name:
|
||||||
# Schedule the async processing in the event loop
|
# Schedule the async processing in the event loop
|
||||||
asyncio.create_task(self._consider_message_response(message))
|
asyncio.create_task(self._consider_message_response(message))
|
||||||
|
|
||||||
def _on_status_update(self, game, notification):
|
def _on_status_update(self, game, notification: GameStatusUpdate):
|
||||||
"""Handle game status changes."""
|
"""Handle game status changes."""
|
||||||
logger.info(f"Game status updated: {notification.status}")
|
logger.info(f"Game status updated: {notification.status}")
|
||||||
|
|
||||||
|
|
@ -261,20 +272,50 @@ class SingleBotPlayer:
|
||||||
# Check if it's a phase where we can submit orders
|
# Check if it's a phase where we can submit orders
|
||||||
current_short_phase = self.client.game.current_short_phase
|
current_short_phase = self.client.game.current_short_phase
|
||||||
|
|
||||||
# We submit orders in Movement and Retreat phases
|
logger.debug(f"Checking if orders needed for phase: {current_short_phase}")
|
||||||
if current_short_phase.endswith("M") or current_short_phase.endswith("R"):
|
|
||||||
# Check if we have units that can receive orders
|
# Movement and Retreat phases
|
||||||
orderable_locations = self.client.game.get_orderable_locations(
|
orderable_locations = self.client.game.get_orderable_locations(self.power_name)
|
||||||
self.power_name
|
|
||||||
)
|
|
||||||
if orderable_locations:
|
if orderable_locations:
|
||||||
logger.info(f"Orders needed for phase {current_short_phase}")
|
logger.info(f"Orders needed for {current_short_phase} phase - orderable locations: {orderable_locations}")
|
||||||
self.waiting_for_orders = True
|
self.waiting_for_orders = True
|
||||||
await self._submit_orders()
|
await self._submit_orders()
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(f"No orderable locations for {self.power_name} in {current_short_phase}")
|
||||||
f"No orderable locations for {self.power_name} in {current_short_phase}"
|
|
||||||
)
|
async def _submit_adjustment_orders(self, action_type, count):
|
||||||
|
"""Submit build or disband orders for adjustment phase."""
|
||||||
|
# FIXME: This whole function is horse shit.
|
||||||
|
return
|
||||||
|
if action_type == "build":
|
||||||
|
# Get buildable locations for this power
|
||||||
|
possible_orders = gather_possible_order(self.client.game, self.power_name)
|
||||||
|
if not possible_orders:
|
||||||
|
logger.warning(f"No possible build orders for {self.power_name}")
|
||||||
|
await self.client.set_orders(self.power_name, [])
|
||||||
|
self.orders_submitted = True
|
||||||
|
return
|
||||||
|
|
||||||
|
# Filter for build orders (usually start with unit type + location)
|
||||||
|
build_orders = [order for order in possible_orders if " - " not in order and any(order.startswith(unit) for unit in ["A ", "F "])]
|
||||||
|
|
||||||
|
logger.info(f"Available build orders for {self.power_name}: {build_orders}")
|
||||||
|
|
||||||
|
# Select up to 'count' build orders
|
||||||
|
selected_orders = build_orders[:count]
|
||||||
|
logger.info(f"Submitting build orders for {self.power_name}: {selected_orders}")
|
||||||
|
|
||||||
|
await self.client.set_orders(self.power_name, selected_orders)
|
||||||
|
|
||||||
|
elif action_type == "disband":
|
||||||
|
# Get current units for disbanding
|
||||||
|
|
||||||
|
logger.info(f"Submitting disband orders for {self.power_name}: {disband_orders}")
|
||||||
|
await self.client.set_orders(self.power_name, disband_orders)
|
||||||
|
|
||||||
|
self.orders_submitted = True
|
||||||
|
self.waiting_for_orders = False
|
||||||
|
logger.info(f"Adjustment orders submitted successfully for {self.power_name}")
|
||||||
|
|
||||||
async def _submit_orders(self):
|
async def _submit_orders(self):
|
||||||
"""Generate and submit orders for the current phase."""
|
"""Generate and submit orders for the current phase."""
|
||||||
|
|
@ -282,8 +323,8 @@ class SingleBotPlayer:
|
||||||
logger.debug("Orders already submitted for this phase")
|
logger.debug("Orders already submitted for this phase")
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
current_phase = self.client.game.get_current_phase()
|
||||||
logger.info("Generating orders...")
|
logger.info(f"Generating orders for {self.power_name} in phase {current_phase}...")
|
||||||
|
|
||||||
# Get current board state
|
# Get current board state
|
||||||
board_state = self.client.game.get_state()
|
board_state = self.client.game.get_state()
|
||||||
|
|
@ -291,8 +332,10 @@ class SingleBotPlayer:
|
||||||
# Get possible orders
|
# Get possible orders
|
||||||
possible_orders = gather_possible_orders(self.client.game, self.power_name)
|
possible_orders = gather_possible_orders(self.client.game, self.power_name)
|
||||||
|
|
||||||
|
logger.debug(f"Possible orders for {self.power_name}: {possible_orders}")
|
||||||
|
|
||||||
if not possible_orders:
|
if not possible_orders:
|
||||||
logger.info("No possible orders, submitting empty order set")
|
logger.info(f"No possible orders for {self.power_name}, submitting empty order set")
|
||||||
await self.client.set_orders(self.power_name, [])
|
await self.client.set_orders(self.power_name, [])
|
||||||
self.orders_submitted = True
|
self.orders_submitted = True
|
||||||
return
|
return
|
||||||
|
|
@ -331,22 +374,12 @@ class SingleBotPlayer:
|
||||||
self.waiting_for_orders = False
|
self.waiting_for_orders = False
|
||||||
logger.info("Orders submitted successfully")
|
logger.info("Orders submitted successfully")
|
||||||
# Call the no wait so we don't sit around for the turns to end.
|
# Call the no wait so we don't sit around for the turns to end.
|
||||||
|
# TODO: We probably don't want to call this here.
|
||||||
|
# We want to call it when negotiations end,
|
||||||
self.client.game.no_wait()
|
self.client.game.no_wait()
|
||||||
|
|
||||||
except DiplomacyException as e:
|
|
||||||
logger.error(f"Error submitting orders: {e}", exc_info=True)
|
|
||||||
# FIXME: I don't think we want to do this. Likely want to retry again multiple times.
|
|
||||||
#
|
|
||||||
# Submit empty orders as fallback
|
|
||||||
try:
|
|
||||||
await self.client.set_orders(self.power_name, [])
|
|
||||||
self.orders_submitted = True
|
|
||||||
except Exception as fallback_error:
|
|
||||||
logger.error(f"Failed to submit fallback orders: {fallback_error}")
|
|
||||||
|
|
||||||
async def _analyze_phase_results(self):
|
async def _analyze_phase_results(self):
|
||||||
"""Analyze the results of the previous phase."""
|
"""Analyze the results of the previous phase."""
|
||||||
try:
|
|
||||||
logger.info("Analyzing phase results...")
|
logger.info("Analyzing phase results...")
|
||||||
|
|
||||||
# Get current board state after processing
|
# Get current board state after processing
|
||||||
|
|
@ -366,24 +399,204 @@ class SingleBotPlayer:
|
||||||
|
|
||||||
logger.info("Phase analysis complete")
|
logger.info("Phase analysis complete")
|
||||||
|
|
||||||
except Exception as e:
|
async def _handle_negotiation_phase(self):
|
||||||
logger.error(f"Error analyzing phase results: {e}", exc_info=True)
|
"""Handle the negotiation phase for movement turns."""
|
||||||
|
# Check if we should participate in negotiations
|
||||||
|
if not await should_participate_in_negotiations(self.client, self.agent):
|
||||||
|
logger.info(f"{self.power_name} will not participate in negotiations this phase")
|
||||||
|
self.negotiation_complete = True
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"Starting negotiation phase for {self.power_name}")
|
||||||
|
# TODO: This doesn't need a specific number of negotiation rounds, though it should have a top number of messages this turn so they don't blabber on forever.
|
||||||
|
|
||||||
|
# Conduct negotiations for the specified number of rounds
|
||||||
|
for round_num in range(1, self.negotiation_rounds + 1):
|
||||||
|
self.current_negotiation_round = round_num
|
||||||
|
|
||||||
|
logger.info(f"Negotiation round {round_num}/{self.negotiation_rounds} for {self.power_name}")
|
||||||
|
|
||||||
|
# Use strategic negotiation that analyzes recent messages
|
||||||
|
success = await conduct_strategic_negotiation_round(
|
||||||
|
client=self.client,
|
||||||
|
agent=self.agent,
|
||||||
|
game_history=self.game_history,
|
||||||
|
model_error_stats=self.error_stats,
|
||||||
|
log_file_path=config.log_file_path,
|
||||||
|
round_number=round_num,
|
||||||
|
max_rounds=self.negotiation_rounds,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
logger.info(f"No messages sent in round {round_num} for {self.power_name}")
|
||||||
|
|
||||||
|
# Wait between rounds to allow other bots to respond
|
||||||
|
if round_num < self.negotiation_rounds:
|
||||||
|
delay = get_negotiation_delay(round_num, self.negotiation_rounds)
|
||||||
|
logger.debug(f"Waiting {delay}s before next negotiation round")
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
|
self.negotiation_complete = True
|
||||||
|
logger.info(f"Negotiation phase complete for {self.power_name}")
|
||||||
|
|
||||||
async def _consider_message_response(self, message: Message):
|
async def _consider_message_response(self, message: Message):
|
||||||
"""Consider whether to respond to a diplomatic message."""
|
"""Consider whether to respond to a diplomatic message."""
|
||||||
try:
|
try:
|
||||||
# Simple logic: if someone greets us, greet back
|
# Only respond to messages directed at us specifically
|
||||||
if any(
|
if message.recipient != self.power_name:
|
||||||
word in message.message.lower() for word in ["hello", "hi", "greetings"]
|
return
|
||||||
):
|
|
||||||
response = f"Hello {message.sender}! Good to hear from you."
|
# Don't respond to our own messages
|
||||||
await self.client.game.send_game_message(
|
if message.sender == self.power_name:
|
||||||
sender=self.power_name, recipient=message.sender, message=response
|
return
|
||||||
|
|
||||||
|
logger.info(f"Considering response to message from {message.sender}: {message.message[:50]}...")
|
||||||
|
|
||||||
|
# Enhanced heuristic: respond to direct questions, proposals, and strategic keywords
|
||||||
|
message_lower = message.message.lower()
|
||||||
|
strategic_keywords = [
|
||||||
|
"alliance",
|
||||||
|
"deal",
|
||||||
|
"propose",
|
||||||
|
"agreement",
|
||||||
|
"support",
|
||||||
|
"attack",
|
||||||
|
"coordinate",
|
||||||
|
"move",
|
||||||
|
"order",
|
||||||
|
"help",
|
||||||
|
"work together",
|
||||||
|
"partner",
|
||||||
|
"enemy",
|
||||||
|
"threat",
|
||||||
|
"negotiate",
|
||||||
|
"discuss",
|
||||||
|
"plan",
|
||||||
|
"strategy",
|
||||||
|
"bounce",
|
||||||
|
"convoy",
|
||||||
|
"retreat",
|
||||||
|
]
|
||||||
|
|
||||||
|
should_respond = any(
|
||||||
|
[
|
||||||
|
"?" in message.message, # Questions
|
||||||
|
any(word in message_lower for word in ["hello", "hi", "greetings"]), # Greetings
|
||||||
|
any(keyword in message_lower for keyword in strategic_keywords), # Strategic content
|
||||||
|
len(message.message.split()) > 15, # Longer messages suggest they want engagement
|
||||||
|
message.sender in self.priority_contacts, # Priority contacts
|
||||||
|
]
|
||||||
)
|
)
|
||||||
logger.info(f"Sent response to {message.sender}: {response}")
|
|
||||||
|
if should_respond:
|
||||||
|
# Generate a contextual response using AI
|
||||||
|
# Get current game state for context
|
||||||
|
board_state = self.client.get_state()
|
||||||
|
possible_orders = gather_possible_orders(self.client.game, self.power_name)
|
||||||
|
|
||||||
|
# Create a simple conversation context
|
||||||
|
active_powers = [p_name for p_name, p_obj in self.client.powers.items() if not p_obj.is_eliminated()]
|
||||||
|
|
||||||
|
# Generate response using the agent's conversation capabilities
|
||||||
|
responses = await self.agent.client.get_conversation_reply(
|
||||||
|
game=self.client.game,
|
||||||
|
board_state=board_state,
|
||||||
|
power_name=self.power_name,
|
||||||
|
possible_orders=possible_orders,
|
||||||
|
game_history=self.game_history,
|
||||||
|
game_phase=self.client.get_current_short_phase(),
|
||||||
|
log_file_path=config.log_file_path,
|
||||||
|
active_powers=active_powers,
|
||||||
|
agent_goals=self.agent.goals,
|
||||||
|
agent_relationships=self.agent.relationships,
|
||||||
|
agent_private_diary_str=self.agent.format_private_diary_for_prompt(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send the first response if any were generated
|
||||||
|
if responses and len(responses) > 0:
|
||||||
|
response_content = responses[0].get("content", "").strip()
|
||||||
|
if response_content:
|
||||||
|
await self.client.send_message(
|
||||||
|
sender=self.power_name,
|
||||||
|
recipient=message.sender,
|
||||||
|
message=response_content,
|
||||||
|
phase=self.client.get_current_short_phase(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add to game history
|
||||||
|
self.game_history.add_message(
|
||||||
|
phase_name=self.client.get_current_short_phase(),
|
||||||
|
sender=self.power_name,
|
||||||
|
recipient=message.sender,
|
||||||
|
message_content=response_content,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Track response patterns
|
||||||
|
self.response_counts[message.sender] = self.response_counts.get(message.sender, 0) + 1
|
||||||
|
|
||||||
|
# Add to agent's journal
|
||||||
|
self.agent.add_journal_entry(
|
||||||
|
f"Responded to {message.sender} in {self.client.get_current_short_phase()}: {response_content[:100]}..."
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Sent AI response to {message.sender}: {response_content[:50]}...")
|
||||||
|
else:
|
||||||
|
logger.debug(f"AI generated empty response to {message.sender}")
|
||||||
|
else:
|
||||||
|
logger.debug(f"AI generated no responses to {message.sender}")
|
||||||
|
else:
|
||||||
|
logger.debug(f"Decided not to respond to message from {message.sender}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error responding to message: {e}")
|
logger.error(f"Error responding to message: {e}", exc_info=True)
|
||||||
|
|
||||||
|
def _update_priority_contacts(self) -> None:
|
||||||
|
"""Update the list of priority contacts based on messaging patterns."""
|
||||||
|
# Sort powers by message count (descending) and take top 3-4
|
||||||
|
sorted_contacts = sorted(self.message_counts.items(), key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
|
# Keep top 4 most active contacts as priority
|
||||||
|
self.priority_contacts = [contact[0] for contact in sorted_contacts[:4]]
|
||||||
|
|
||||||
|
logger.debug(f"Updated priority contacts for {self.power_name}: {self.priority_contacts}")
|
||||||
|
|
||||||
|
def get_message_statistics(self) -> Dict[str, any]:
|
||||||
|
"""Get comprehensive statistics about messaging patterns."""
|
||||||
|
active_powers = [p_name for p_name, p_obj in self.client.powers.items() if not p_obj.is_eliminated() and p_name != self.power_name]
|
||||||
|
|
||||||
|
stats = {
|
||||||
|
"power_name": self.power_name,
|
||||||
|
"total_messages_received": sum(self.message_counts.values()),
|
||||||
|
"total_responses_sent": sum(self.response_counts.values()),
|
||||||
|
"message_counts_by_power": dict(self.message_counts),
|
||||||
|
"response_counts_by_power": dict(self.response_counts),
|
||||||
|
"priority_contacts": list(self.priority_contacts),
|
||||||
|
"response_rate_by_power": {},
|
||||||
|
"active_powers": active_powers,
|
||||||
|
"current_phase": self.current_phase,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Calculate response rates
|
||||||
|
for power, received in self.message_counts.items():
|
||||||
|
sent = self.response_counts.get(power, 0)
|
||||||
|
stats["response_rate_by_power"][power] = sent / received if received > 0 else 0.0
|
||||||
|
|
||||||
|
return stats
|
||||||
|
|
||||||
|
def log_message_statistics(self) -> None:
|
||||||
|
"""Log current message statistics for analysis."""
|
||||||
|
stats = self.get_message_statistics()
|
||||||
|
|
||||||
|
logger.info(f"Message Statistics for {self.power_name}:")
|
||||||
|
logger.info(f" Total messages received: {stats['total_messages_received']}")
|
||||||
|
logger.info(f" Total responses sent: {stats['total_responses_sent']}")
|
||||||
|
logger.info(f" Priority contacts: {stats['priority_contacts']}")
|
||||||
|
|
||||||
|
for power in stats["active_powers"]:
|
||||||
|
received = stats["message_counts_by_power"].get(power, 0)
|
||||||
|
sent = stats["response_counts_by_power"].get(power, 0)
|
||||||
|
rate = stats["response_rate_by_power"].get(power, 0.0)
|
||||||
|
logger.info(f" {power}: {received} received, {sent} sent, {rate:.1%} response rate")
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""Main bot loop."""
|
"""Main bot loop."""
|
||||||
|
|
@ -408,9 +621,7 @@ class SingleBotPlayer:
|
||||||
else:
|
else:
|
||||||
logger.info("Bot shutting down")
|
logger.info("Bot shutting down")
|
||||||
except GameIdException:
|
except GameIdException:
|
||||||
logger.error(
|
logger.error(f"Game with id {self.game_id} does not exist on the server. Exiting...")
|
||||||
f"Game with id {self.game_id} does not exist on the server. Exiting..."
|
|
||||||
)
|
|
||||||
finally:
|
finally:
|
||||||
await self.cleanup()
|
await self.cleanup()
|
||||||
|
|
||||||
|
|
@ -437,10 +648,14 @@ def parse_arguments():
|
||||||
parser.add_argument("--password", default="password", help="Bot password")
|
parser.add_argument("--password", default="password", help="Bot password")
|
||||||
parser.add_argument("--power", default="FRANCE", help="Power to control")
|
parser.add_argument("--power", default="FRANCE", help="Power to control")
|
||||||
parser.add_argument("--model", default="gpt-3.5-turbo", help="AI model to use")
|
parser.add_argument("--model", default="gpt-3.5-turbo", help="AI model to use")
|
||||||
parser.add_argument(
|
parser.add_argument("--game-id", help="Game ID to join (creates new if not specified)")
|
||||||
"--game-id", help="Game ID to join (creates new if not specified)"
|
|
||||||
)
|
|
||||||
parser.add_argument("--log-level", default="INFO", help="Logging level")
|
parser.add_argument("--log-level", default="INFO", help="Logging level")
|
||||||
|
parser.add_argument(
|
||||||
|
"--negotiation-rounds",
|
||||||
|
type=int,
|
||||||
|
default=3,
|
||||||
|
help="Number of negotiation rounds per movement phase (default: 3)",
|
||||||
|
)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
@ -459,8 +674,9 @@ async def main():
|
||||||
power_name=args.power,
|
power_name=args.power,
|
||||||
model_name=args.model,
|
model_name=args.model,
|
||||||
game_id=args.game_id,
|
game_id=args.game_id,
|
||||||
|
negotiation_rounds=args.negotiation_rounds,
|
||||||
)
|
)
|
||||||
|
config = Configuration(args.power)
|
||||||
await bot.run()
|
await bot.run()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -27,9 +27,7 @@ class WebSocketDiplomacyClient:
|
||||||
but communicates with a remote server via WebSocket connections.
|
but communicates with a remote server via WebSocket connections.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, hostname: str = "localhost", port: int = 8432, use_ssl: bool = False):
|
||||||
self, hostname: str = "localhost", port: int = 8432, use_ssl: bool = False
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Initialize the WebSocket client.
|
Initialize the WebSocket client.
|
||||||
|
|
||||||
|
|
@ -199,9 +197,7 @@ class WebSocketDiplomacyClient:
|
||||||
|
|
||||||
return await self.channel.get_available_maps()
|
return await self.channel.get_available_maps()
|
||||||
|
|
||||||
async def set_orders(
|
async def set_orders(self, power_name: str, orders: List[str], wait: Optional[bool] = None) -> None:
|
||||||
self, power_name: str, orders: List[str], wait: Optional[bool] = None
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
Set orders for a power.
|
Set orders for a power.
|
||||||
|
|
||||||
|
|
@ -246,9 +242,7 @@ class WebSocketDiplomacyClient:
|
||||||
else:
|
else:
|
||||||
await self.game.no_wait()
|
await self.game.no_wait()
|
||||||
|
|
||||||
async def send_message(
|
async def send_message(self, sender: str, recipient: str, message: str, phase: Optional[str] = None) -> None:
|
||||||
self, sender: str, recipient: str, message: str, phase: Optional[str] = None
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
Send a diplomatic message.
|
Send a diplomatic message.
|
||||||
|
|
||||||
|
|
@ -269,6 +263,57 @@ class WebSocketDiplomacyClient:
|
||||||
logger.debug(f"Sending message from {sender} to {recipient}: {message}")
|
logger.debug(f"Sending message from {sender} to {recipient}: {message}")
|
||||||
await self.game.send_game_message(message=msg)
|
await self.game.send_game_message(message=msg)
|
||||||
|
|
||||||
|
async def send_broadcast_message(self, sender: str, message: str, phase: Optional[str] = None) -> None:
|
||||||
|
"""
|
||||||
|
Send a broadcast message to all active powers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sender: Sending power name
|
||||||
|
message: Message content
|
||||||
|
phase: Game phase (uses current phase if None)
|
||||||
|
"""
|
||||||
|
if not self.game:
|
||||||
|
raise DiplomacyException("Must join a game first")
|
||||||
|
|
||||||
|
if phase is None:
|
||||||
|
phase = self.game.current_short_phase
|
||||||
|
|
||||||
|
# Send to all active powers
|
||||||
|
active_powers = [power_name for power_name, power in self.powers.items() if not power.is_eliminated() and power_name != sender]
|
||||||
|
|
||||||
|
for recipient in active_powers:
|
||||||
|
await self.send_message(sender, recipient, message, phase)
|
||||||
|
|
||||||
|
logger.debug(f"Broadcast message from {sender} to {len(active_powers)} powers: {message}")
|
||||||
|
|
||||||
|
async def get_recent_messages(self, phase: Optional[str] = None, limit: int = 50) -> List[Message]:
|
||||||
|
"""
|
||||||
|
Get recent messages from the game.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
phase: Filter by specific phase (None for current phase)
|
||||||
|
limit: Maximum number of messages to return
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of recent messages
|
||||||
|
"""
|
||||||
|
if not self.game:
|
||||||
|
raise DiplomacyException("Must join a game first")
|
||||||
|
|
||||||
|
if phase is None:
|
||||||
|
phase = self.game.current_short_phase
|
||||||
|
|
||||||
|
# Get messages from the game and filter by phase
|
||||||
|
all_messages = list(self.messages.values())
|
||||||
|
|
||||||
|
# Filter by phase if specified
|
||||||
|
filtered_messages = [msg for msg in all_messages if msg.phase == phase]
|
||||||
|
|
||||||
|
# Sort by time (most recent first) and limit
|
||||||
|
filtered_messages.sort(key=lambda m: m.time_sent or 0, reverse=True)
|
||||||
|
|
||||||
|
return filtered_messages[:limit]
|
||||||
|
|
||||||
async def process_game(self) -> None:
|
async def process_game(self) -> None:
|
||||||
"""
|
"""
|
||||||
Force the game to process immediately (admin/moderator only).
|
Force the game to process immediately (admin/moderator only).
|
||||||
|
|
@ -289,9 +334,7 @@ class WebSocketDiplomacyClient:
|
||||||
logger.debug("Synchronizing game state")
|
logger.debug("Synchronizing game state")
|
||||||
await self.game.synchronize()
|
await self.game.synchronize()
|
||||||
|
|
||||||
async def get_phase_history(
|
async def get_phase_history(self, from_phase: Optional[str] = None, to_phase: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||||
self, from_phase: Optional[str] = None, to_phase: Optional[str] = None
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
"""
|
||||||
Get phase history for the game.
|
Get phase history for the game.
|
||||||
|
|
||||||
|
|
@ -305,9 +348,7 @@ class WebSocketDiplomacyClient:
|
||||||
if not self.game:
|
if not self.game:
|
||||||
raise DiplomacyException("Must join a game first")
|
raise DiplomacyException("Must join a game first")
|
||||||
|
|
||||||
return await self.game.get_phase_history(
|
return await self.game.get_phase_history(from_phase=from_phase, to_phase=to_phase)
|
||||||
from_phase=from_phase, to_phase=to_phase
|
|
||||||
)
|
|
||||||
|
|
||||||
async def vote(self, power_name: str, vote: str) -> None:
|
async def vote(self, power_name: str, vote: str) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
326
bot_client/websocket_negotiations.py
Normal file
326
bot_client/websocket_negotiations.py
Normal file
|
|
@ -0,0 +1,326 @@
|
||||||
|
"""
|
||||||
|
WebSocket-specific negotiation logic for single bot players.
|
||||||
|
|
||||||
|
This module provides negotiation capabilities for individual bots connected
|
||||||
|
via WebSocket, adapted from the multi-agent negotiation system in the main
|
||||||
|
ai_diplomacy package.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, Optional, List
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from diplomacy.engine.message import GLOBAL
|
||||||
|
|
||||||
|
from ai_diplomacy.utils import gather_possible_orders
|
||||||
|
|
||||||
|
from ai_diplomacy.agent import DiplomacyAgent
|
||||||
|
from ai_diplomacy.game_history import GameHistory
|
||||||
|
from websocket_diplomacy_client import WebSocketDiplomacyClient
|
||||||
|
|
||||||
|
|
||||||
|
async def conduct_single_bot_negotiation(
|
||||||
|
client: "WebSocketDiplomacyClient",
|
||||||
|
agent: "DiplomacyAgent",
|
||||||
|
game_history: "GameHistory",
|
||||||
|
model_error_stats: Dict[str, Dict[str, int]],
|
||||||
|
log_file_path: str,
|
||||||
|
max_rounds: int = 3,
|
||||||
|
round_number: int = 1,
|
||||||
|
prioritize_targets: Optional[List[str]] = None,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Conduct negotiation for a single bot during one negotiation round.
|
||||||
|
|
||||||
|
This function handles message generation and sending for one bot during
|
||||||
|
a negotiation phase. Unlike the multi-agent version, this focuses on
|
||||||
|
a single power and sends messages via WebSocket.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client: WebSocket diplomacy client
|
||||||
|
agent: The bot's AI agent
|
||||||
|
game_history: Game history tracker
|
||||||
|
model_error_stats: Error statistics tracking
|
||||||
|
log_file_path: Path for logging
|
||||||
|
max_rounds: Maximum number of negotiation rounds
|
||||||
|
round_number: Current round number (1-indexed)
|
||||||
|
prioritize_targets: Optional list of powers to prioritize for messaging
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if messages were sent successfully, False otherwise
|
||||||
|
"""
|
||||||
|
power_name = agent.power_name
|
||||||
|
logger.info(f"Starting negotiation round {round_number}/{max_rounds} for {power_name}")
|
||||||
|
|
||||||
|
# Check if this power is eliminated
|
||||||
|
if client.get_power(power_name).is_eliminated():
|
||||||
|
logger.info(f"{power_name} is eliminated, skipping negotiation")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if this power has any orderable locations
|
||||||
|
possible_orders = gather_possible_orders(client.game, power_name)
|
||||||
|
if not possible_orders:
|
||||||
|
logger.info(f"No orderable locations for {power_name}, skipping negotiation")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Get active powers for context
|
||||||
|
active_powers = [p_name for p_name, p_obj in client.powers.items() if not p_obj.is_eliminated()]
|
||||||
|
|
||||||
|
# Prioritize message targets if specified
|
||||||
|
message_targets = prioritize_targets if prioritize_targets else active_powers
|
||||||
|
message_targets = [p for p in message_targets if p in active_powers and p != power_name]
|
||||||
|
|
||||||
|
# Generate conversation messages using the AI agent
|
||||||
|
board_state = client.get_state()
|
||||||
|
|
||||||
|
messages = await agent.client.get_conversation_reply(
|
||||||
|
game=client.game,
|
||||||
|
board_state=board_state,
|
||||||
|
power_name=power_name,
|
||||||
|
possible_orders=possible_orders,
|
||||||
|
game_history=game_history,
|
||||||
|
game_phase=client.get_current_short_phase(),
|
||||||
|
log_file_path=log_file_path,
|
||||||
|
active_powers=active_powers,
|
||||||
|
agent_goals=agent.goals,
|
||||||
|
agent_relationships=agent.relationships,
|
||||||
|
agent_private_diary_str=agent.format_private_diary_for_prompt(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if not messages:
|
||||||
|
logger.debug(f"No messages generated for {power_name} in round {round_number}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Process and send each message
|
||||||
|
messages_sent = 0
|
||||||
|
for message in messages:
|
||||||
|
success = await _send_negotiation_message(client, agent, game_history, message, power_name)
|
||||||
|
if success:
|
||||||
|
messages_sent += 1
|
||||||
|
|
||||||
|
logger.info(f"Sent {messages_sent}/{len(messages)} messages for {power_name}")
|
||||||
|
return messages_sent > 0
|
||||||
|
|
||||||
|
|
||||||
|
async def _send_negotiation_message(
|
||||||
|
client: WebSocketDiplomacyClient,
|
||||||
|
agent: DiplomacyAgent,
|
||||||
|
game_history: GameHistory,
|
||||||
|
message: Dict,
|
||||||
|
power_name: str,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Send a single negotiation message via WebSocket.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client: WebSocket diplomacy client
|
||||||
|
agent: The bot's AI agent
|
||||||
|
game_history: Game history tracker
|
||||||
|
message: Message dictionary with content and metadata
|
||||||
|
power_name: Name of the sending power
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if message was sent successfully, False otherwise
|
||||||
|
"""
|
||||||
|
# Validate message structure
|
||||||
|
if not isinstance(message, dict) or "content" not in message:
|
||||||
|
logger.warning(f"Invalid message format from {power_name}: {message}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
content = message.get("content", "").strip()
|
||||||
|
if not content:
|
||||||
|
logger.debug(f"Empty message content from {power_name}, skipping")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Determine recipient
|
||||||
|
recipient = GLOBAL # Default to global
|
||||||
|
if message.get("message_type") == "private":
|
||||||
|
recipient = message.get("recipient", GLOBAL)
|
||||||
|
# Validate recipient is a valid power
|
||||||
|
if recipient not in client.powers and recipient != GLOBAL:
|
||||||
|
logger.warning(f"Invalid recipient '{recipient}' from {power_name}, sending globally")
|
||||||
|
recipient = GLOBAL
|
||||||
|
|
||||||
|
# Send the message via WebSocket
|
||||||
|
await client.send_message(
|
||||||
|
sender=power_name,
|
||||||
|
recipient=recipient,
|
||||||
|
message=content,
|
||||||
|
phase=client.get_current_short_phase(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add to game history
|
||||||
|
game_history.add_message(
|
||||||
|
phase_name=client.get_current_short_phase(),
|
||||||
|
sender=power_name,
|
||||||
|
recipient=recipient,
|
||||||
|
message_content=content,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add to agent's journal
|
||||||
|
journal_recipient = f"to {recipient}" if recipient != GLOBAL else "globally"
|
||||||
|
agent.add_journal_entry(f"Sent message {journal_recipient} in {client.get_current_short_phase()}: {content[:100]}...")
|
||||||
|
|
||||||
|
logger.info(f"[{power_name} -> {recipient}] {content[:100]}...")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def should_participate_in_negotiations(
|
||||||
|
client: "WebSocketDiplomacyClient",
|
||||||
|
agent: "DiplomacyAgent",
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Determine if this bot should participate in negotiations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client: WebSocket diplomacy client
|
||||||
|
agent: The bot's AI agent
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the bot should participate in negotiations
|
||||||
|
"""
|
||||||
|
power_name = agent.power_name
|
||||||
|
|
||||||
|
# Don't negotiate if eliminated
|
||||||
|
if client.get_power(power_name).is_eliminated():
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Don't negotiate if no orderable locations
|
||||||
|
possible_orders = gather_possible_orders(client.game, power_name)
|
||||||
|
if not possible_orders:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Only negotiate during movement phases
|
||||||
|
current_phase = client.get_current_short_phase()
|
||||||
|
if not current_phase.endswith("M"):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def get_negotiation_delay(round_number: int, total_rounds: int) -> float:
|
||||||
|
"""
|
||||||
|
Calculate delay between negotiation rounds to allow message processing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
round_number: Current round number (1-indexed)
|
||||||
|
total_rounds: Total number of rounds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Delay in seconds
|
||||||
|
"""
|
||||||
|
# Longer delay in early rounds to allow more strategic messaging
|
||||||
|
base_delay = 10.0 # Base delay between rounds
|
||||||
|
|
||||||
|
if round_number == 1:
|
||||||
|
return base_delay * 1.5 # Extra time for first round
|
||||||
|
elif round_number == total_rounds:
|
||||||
|
return base_delay * 0.5 # Less time for final round
|
||||||
|
else:
|
||||||
|
return base_delay
|
||||||
|
|
||||||
|
|
||||||
|
async def analyze_recent_messages_for_targeting(
|
||||||
|
client: "WebSocketDiplomacyClient",
|
||||||
|
power_name: str,
|
||||||
|
max_messages: int = 20,
|
||||||
|
) -> List[str]:
|
||||||
|
"""
|
||||||
|
Analyze recent messages to identify which powers should be prioritized for negotiations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client: WebSocket diplomacy client
|
||||||
|
power_name: Name of the analyzing power
|
||||||
|
max_messages: Maximum number of recent messages to analyze
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of power names in order of priority for messaging
|
||||||
|
"""
|
||||||
|
# Get recent messages from current phase
|
||||||
|
recent_messages = await client.get_recent_messages(limit=max_messages)
|
||||||
|
|
||||||
|
# Track who has been active and who has messaged us
|
||||||
|
message_activity = {}
|
||||||
|
direct_messages_to_us = {}
|
||||||
|
|
||||||
|
for message in recent_messages:
|
||||||
|
sender = message.sender
|
||||||
|
recipient = message.recipient
|
||||||
|
|
||||||
|
# Track general activity
|
||||||
|
if sender != power_name:
|
||||||
|
message_activity[sender] = message_activity.get(sender, 0) + 1
|
||||||
|
|
||||||
|
# Track direct messages to us
|
||||||
|
if recipient == power_name and sender != power_name:
|
||||||
|
direct_messages_to_us[sender] = direct_messages_to_us.get(sender, 0) + 1
|
||||||
|
|
||||||
|
# Get all active powers
|
||||||
|
active_powers = [p_name for p_name, p_obj in client.powers.items() if not p_obj.is_eliminated() and p_name != power_name]
|
||||||
|
|
||||||
|
# Prioritize based on: 1) Powers that messaged us directly, 2) Most active powers
|
||||||
|
priority_list = []
|
||||||
|
|
||||||
|
# First, add powers that sent us direct messages (sorted by count)
|
||||||
|
direct_senders = sorted(direct_messages_to_us.items(), key=lambda x: x[1], reverse=True)
|
||||||
|
for sender, _ in direct_senders:
|
||||||
|
if sender in active_powers:
|
||||||
|
priority_list.append(sender)
|
||||||
|
|
||||||
|
# Then add other active powers (sorted by activity)
|
||||||
|
remaining_powers = [p for p in active_powers if p not in priority_list]
|
||||||
|
activity_sorted = sorted(remaining_powers, key=lambda p: message_activity.get(p, 0), reverse=True)
|
||||||
|
priority_list.extend(activity_sorted)
|
||||||
|
|
||||||
|
logger.debug(f"Message targeting priority for {power_name}: {priority_list}")
|
||||||
|
return priority_list
|
||||||
|
|
||||||
|
|
||||||
|
async def conduct_strategic_negotiation_round(
|
||||||
|
client: "WebSocketDiplomacyClient",
|
||||||
|
agent: "DiplomacyAgent",
|
||||||
|
game_history: "GameHistory",
|
||||||
|
model_error_stats: Dict[str, Dict[str, int]],
|
||||||
|
log_file_path: str,
|
||||||
|
round_number: int,
|
||||||
|
max_rounds: int = 3,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Conduct a single negotiation round with strategic message targeting.
|
||||||
|
|
||||||
|
This function analyzes recent message activity to determine which powers
|
||||||
|
to prioritize for messaging in this round.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client: WebSocket diplomacy client
|
||||||
|
agent: The bot's AI agent
|
||||||
|
game_history: Game history tracker
|
||||||
|
model_error_stats: Error statistics tracking
|
||||||
|
log_file_path: Path for logging
|
||||||
|
round_number: Current round number (1-indexed)
|
||||||
|
max_rounds: Maximum number of negotiation rounds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if messages were sent successfully, False otherwise
|
||||||
|
"""
|
||||||
|
power_name = agent.power_name
|
||||||
|
|
||||||
|
# Analyze recent messages to prioritize targets
|
||||||
|
priority_targets = await analyze_recent_messages_for_targeting(client, power_name)
|
||||||
|
|
||||||
|
# Limit to top 3-4 targets in later rounds to focus conversations
|
||||||
|
if round_number > 1:
|
||||||
|
priority_targets = priority_targets[: min(4, len(priority_targets))]
|
||||||
|
|
||||||
|
logger.info(f"Round {round_number} targets for {power_name}: {priority_targets}")
|
||||||
|
|
||||||
|
# Conduct negotiation with prioritized targets
|
||||||
|
return await conduct_single_bot_negotiation(
|
||||||
|
client=client,
|
||||||
|
agent=agent,
|
||||||
|
game_history=game_history,
|
||||||
|
model_error_stats=model_error_stats,
|
||||||
|
log_file_path=log_file_path,
|
||||||
|
max_rounds=max_rounds,
|
||||||
|
round_number=round_number,
|
||||||
|
prioritize_targets=priority_targets,
|
||||||
|
)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue