mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-19 12:58:09 +00:00
WIP: Continual work on the bot client. They successfully play the game, but don't negotiate
Signed-off-by: Tyler Marques <me@tylermarques.com>
This commit is contained in:
parent
c09d5bfe7e
commit
79a31d615d
13 changed files with 2483 additions and 263 deletions
|
|
@ -6,7 +6,7 @@ import warnings
|
|||
|
||||
class Configuration(BaseSettings):
|
||||
DEBUG: bool = False
|
||||
log_file_path: Path
|
||||
log_file_path: Path | None = None
|
||||
DEEPSEEK_API_KEY: str | None = None
|
||||
OPENAI_API_KEY: str | None = None
|
||||
ANTHROPIC_API_KEY: str | None = None
|
||||
|
|
|
|||
|
|
@ -72,7 +72,7 @@ class MultiBotLauncher:
|
|||
logger.info("Creating new game...")
|
||||
|
||||
# Connect as the game creator
|
||||
creator_username = f"{self.base_username}_{creator_power.lower()}"
|
||||
creator_username = f"{self.base_username}_{creator_power}"
|
||||
client = await connect_to_diplomacy_server(
|
||||
hostname=self.hostname,
|
||||
port=self.port,
|
||||
|
|
@ -105,6 +105,9 @@ class MultiBotLauncher:
|
|||
game_id: str,
|
||||
log_level: str = "INFO",
|
||||
negotiation_rounds: int = 3,
|
||||
connection_timeout: float = 30.0,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 2.0,
|
||||
) -> subprocess.Popen:
|
||||
"""
|
||||
Launch a single bot process.
|
||||
|
|
@ -141,6 +144,12 @@ class MultiBotLauncher:
|
|||
log_level,
|
||||
"--negotiation-rounds",
|
||||
str(negotiation_rounds),
|
||||
"--connection-timeout",
|
||||
str(connection_timeout),
|
||||
"--max-retries",
|
||||
str(max_retries),
|
||||
"--retry-delay",
|
||||
str(retry_delay),
|
||||
]
|
||||
|
||||
logger.info(f"Launching bot for {power} with model {model}")
|
||||
|
|
@ -165,6 +174,9 @@ class MultiBotLauncher:
|
|||
log_level: str = "INFO",
|
||||
stagger_delay: float = 0.5,
|
||||
negotiation_rounds: int = 3,
|
||||
connection_timeout: float = 30.0,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 2.0,
|
||||
):
|
||||
"""
|
||||
Launch bots for all specified powers.
|
||||
|
|
@ -188,7 +200,7 @@ class MultiBotLauncher:
|
|||
model = models.get(power, "gpt-3.5-turbo")
|
||||
|
||||
try:
|
||||
process = self.launch_bot(power, model, game_id, log_level, negotiation_rounds)
|
||||
process = self.launch_bot(power, model, game_id, log_level, negotiation_rounds, connection_timeout, max_retries, retry_delay)
|
||||
self.bot_processes.append(process)
|
||||
self.process_to_power[process] = power
|
||||
|
||||
|
|
@ -266,7 +278,7 @@ class MultiBotLauncher:
|
|||
break
|
||||
print(f"{power}_{process.pid}: {line.strip()}")
|
||||
lines_read += 1
|
||||
except:
|
||||
except Exception:
|
||||
break
|
||||
|
||||
except (OSError, ValueError):
|
||||
|
|
@ -299,7 +311,7 @@ class MultiBotLauncher:
|
|||
break
|
||||
print(f"{power}_{process.pid}: {line.strip()}")
|
||||
lines_read += 1
|
||||
except:
|
||||
except Exception:
|
||||
break
|
||||
|
||||
def stop_all_bots(self):
|
||||
|
|
@ -329,6 +341,9 @@ class MultiBotLauncher:
|
|||
log_level: str = "INFO",
|
||||
creator_power: str = "FRANCE",
|
||||
negotiation_rounds: int = 3,
|
||||
connection_timeout: float = 30.0,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 2.0,
|
||||
):
|
||||
"""
|
||||
Create a game and launch all bots for a complete game.
|
||||
|
|
@ -352,6 +367,9 @@ class MultiBotLauncher:
|
|||
models,
|
||||
log_level=log_level,
|
||||
negotiation_rounds=negotiation_rounds,
|
||||
connection_timeout=connection_timeout,
|
||||
max_retries=max_retries,
|
||||
retry_delay=retry_delay,
|
||||
)
|
||||
|
||||
# Monitor the bots
|
||||
|
|
@ -369,6 +387,9 @@ class MultiBotLauncher:
|
|||
models: Optional[Dict[str, str]] = None,
|
||||
log_level: str = "INFO",
|
||||
negotiation_rounds: int = 3,
|
||||
connection_timeout: float = 30.0,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 2.0,
|
||||
):
|
||||
"""
|
||||
Launch bots to join an existing game.
|
||||
|
|
@ -389,6 +410,9 @@ class MultiBotLauncher:
|
|||
powers,
|
||||
log_level,
|
||||
negotiation_rounds=negotiation_rounds,
|
||||
connection_timeout=connection_timeout,
|
||||
max_retries=max_retries,
|
||||
retry_delay=retry_delay,
|
||||
)
|
||||
|
||||
# Monitor the bots
|
||||
|
|
@ -419,12 +443,33 @@ def parse_arguments():
|
|||
default=3,
|
||||
help="Number of negotiation rounds per movement phase (default: 3)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--connection-timeout",
|
||||
type=float,
|
||||
default=30.0,
|
||||
help="Timeout for network operations in seconds (default: 30.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-retries",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Maximum number of retries for failed operations (default: 3)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--retry-delay",
|
||||
type=float,
|
||||
default=2.0,
|
||||
help="Base delay between retries in seconds (default: 2.0)",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main entry point."""
|
||||
|
||||
# FIXME: Arg parse appears to not like game ids with hypens in the name. e.g.
|
||||
# uv run python multi_bot_launcher.py --game-id "-1D0i-fobmvprIh1" results in an error
|
||||
args = parse_arguments()
|
||||
|
||||
launcher = MultiBotLauncher(
|
||||
|
|
@ -454,6 +499,9 @@ async def main():
|
|||
models=models,
|
||||
log_level=args.log_level,
|
||||
negotiation_rounds=args.negotiation_rounds,
|
||||
connection_timeout=args.connection_timeout,
|
||||
max_retries=args.max_retries,
|
||||
retry_delay=args.retry_delay,
|
||||
)
|
||||
else:
|
||||
# Create new game and launch all bots
|
||||
|
|
@ -462,6 +510,9 @@ async def main():
|
|||
log_level=args.log_level,
|
||||
creator_power=args.creator_power,
|
||||
negotiation_rounds=args.negotiation_rounds,
|
||||
connection_timeout=args.connection_timeout,
|
||||
max_retries=args.max_retries,
|
||||
retry_delay=args.retry_delay,
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
|||
import argparse
|
||||
import asyncio
|
||||
import signal
|
||||
import time
|
||||
from typing import Optional, Dict, List
|
||||
import dotenv
|
||||
from loguru import logger
|
||||
|
|
@ -21,8 +22,7 @@ from loguru import logger
|
|||
from websocket_diplomacy_client import WebSocketDiplomacyClient, connect_to_diplomacy_server
|
||||
|
||||
|
||||
from diplomacy.engine.power import Power
|
||||
from diplomacy.utils.exceptions import GameIdException
|
||||
from diplomacy.utils.exceptions import GameIdException, DiplomacyException
|
||||
from diplomacy.communication.notifications import GameStatusUpdate
|
||||
from diplomacy.engine.message import Message
|
||||
|
||||
|
|
@ -41,7 +41,6 @@ from websocket_negotiations import (
|
|||
dotenv.load_dotenv()
|
||||
|
||||
# TODO: This, but better
|
||||
config: Configuration = None
|
||||
|
||||
|
||||
class SingleBotPlayer:
|
||||
|
|
@ -64,6 +63,9 @@ class SingleBotPlayer:
|
|||
port: int = 8432,
|
||||
game_id: Optional[str] = None,
|
||||
negotiation_rounds: int = 3,
|
||||
connection_timeout: float = 30.0,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 2.0,
|
||||
):
|
||||
assert username is not None
|
||||
assert password is not None
|
||||
|
|
@ -77,6 +79,7 @@ class SingleBotPlayer:
|
|||
self.power_name = power_name
|
||||
self.model_name = model_name
|
||||
self.game_id = game_id
|
||||
self.config = Configuration(power_name=power_name)
|
||||
|
||||
# Bot state
|
||||
self.client: WebSocketDiplomacyClient
|
||||
|
|
@ -100,6 +103,21 @@ class SingleBotPlayer:
|
|||
self.response_counts: Dict[str, int] = {} # Responses sent to each power
|
||||
self.priority_contacts: List[str] = [] # Powers to prioritize for communication
|
||||
|
||||
# Connection health and fault tolerance (configurable)
|
||||
self.connection_timeout = connection_timeout
|
||||
self.retry_delay = retry_delay
|
||||
self.max_retries = max_retries
|
||||
self.last_successful_operation = time.time()
|
||||
self.connection_failures = 0
|
||||
self.circuit_breaker_open = False
|
||||
self.circuit_breaker_last_failure = 0
|
||||
self.circuit_breaker_timeout = 60.0 # 1 minute before trying again
|
||||
|
||||
# Add graceful shutdown flag
|
||||
self.shutdown_requested = False
|
||||
|
||||
logger.info(f"Fault tolerance config: timeout={connection_timeout}s, max_retries={max_retries}, retry_delay={retry_delay}s")
|
||||
|
||||
# Setup signal handlers for graceful shutdown
|
||||
signal.signal(signal.SIGINT, self._signal_handler)
|
||||
signal.signal(signal.SIGTERM, self._signal_handler)
|
||||
|
|
@ -109,12 +127,80 @@ class SingleBotPlayer:
|
|||
logger.info(f"Received signal {signum}, shutting down...")
|
||||
self.running = False
|
||||
|
||||
def _is_circuit_breaker_open(self) -> bool:
|
||||
"""Check if circuit breaker is open (preventing operations due to failures)."""
|
||||
if not self.circuit_breaker_open:
|
||||
return False
|
||||
|
||||
# Check if timeout has passed and we should try again
|
||||
if time.time() - self.circuit_breaker_last_failure > self.circuit_breaker_timeout:
|
||||
logger.info("Circuit breaker timeout expired, allowing operations")
|
||||
self.circuit_breaker_open = False
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _record_operation_success(self):
|
||||
"""Record a successful operation."""
|
||||
self.last_successful_operation = time.time()
|
||||
self.connection_failures = 0
|
||||
if self.circuit_breaker_open:
|
||||
logger.info("Operation successful, closing circuit breaker")
|
||||
self.circuit_breaker_open = False
|
||||
|
||||
def _record_operation_failure(self):
|
||||
"""Record a failed operation and potentially open circuit breaker."""
|
||||
self.connection_failures += 1
|
||||
logger.warning(f"Operation failed, failure count: {self.connection_failures}")
|
||||
|
||||
if self.connection_failures >= 5: # Open circuit after 5 consecutive failures
|
||||
logger.error("Opening circuit breaker due to repeated failures")
|
||||
self.circuit_breaker_open = True
|
||||
self.circuit_breaker_last_failure = time.time()
|
||||
|
||||
async def _retry_with_backoff(self, operation, *args, **kwargs):
|
||||
"""Execute an operation with exponential backoff retry logic."""
|
||||
if self._is_circuit_breaker_open():
|
||||
raise DiplomacyException("Circuit breaker is open, operation not allowed")
|
||||
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
result = await asyncio.wait_for(operation(*args, **kwargs), timeout=self.connection_timeout)
|
||||
self._record_operation_success()
|
||||
return result
|
||||
|
||||
except (TimeoutError, asyncio.TimeoutError) as e:
|
||||
last_exception = e
|
||||
logger.warning(f"Operation timeout on attempt {attempt + 1}/{self.max_retries + 1}: {e}")
|
||||
|
||||
except (ConnectionError, DiplomacyException) as e:
|
||||
last_exception = e
|
||||
logger.warning(f"Connection error on attempt {attempt + 1}/{self.max_retries + 1}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
logger.error(f"Unexpected error on attempt {attempt + 1}/{self.max_retries + 1}: {e}")
|
||||
|
||||
# Don't delay after the last attempt
|
||||
if attempt < self.max_retries:
|
||||
delay = self.retry_delay * (2**attempt) # Exponential backoff
|
||||
logger.info(f"Retrying in {delay:.1f} seconds...")
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# All retries failed
|
||||
self._record_operation_failure()
|
||||
logger.error(f"Operation failed after {self.max_retries + 1} attempts")
|
||||
raise last_exception or Exception("Operation failed with unknown error")
|
||||
|
||||
async def connect_and_initialize(self):
|
||||
"""Connect to the server and initialize the bot."""
|
||||
logger.info(f"Connecting to {self.hostname}:{self.port} as {self.username}")
|
||||
|
||||
# Connect to server
|
||||
self.client = await connect_to_diplomacy_server(
|
||||
# Connect to server with retry logic
|
||||
self.client = await self._retry_with_backoff(
|
||||
connect_to_diplomacy_server,
|
||||
hostname=self.hostname,
|
||||
port=self.port,
|
||||
username=self.username,
|
||||
|
|
@ -142,7 +228,7 @@ class SingleBotPlayer:
|
|||
self.agent = DiplomacyAgent(power_name=self.power_name, client=model_client)
|
||||
|
||||
# Initialize agent state
|
||||
await initialize_agent_state_ext(self.agent, self.client.game, self.game_history, config.log_file_path)
|
||||
await initialize_agent_state_ext(self.agent, self.client.game, self.game_history, self.config.log_file_path)
|
||||
|
||||
# Setup game event callbacks
|
||||
await self._setup_event_callbacks()
|
||||
|
|
@ -187,8 +273,14 @@ class SingleBotPlayer:
|
|||
|
||||
async def _handle_phase_update_async(self, notification):
|
||||
"""Async handler for phase updates."""
|
||||
# Update our game state
|
||||
await self.client.game.synchronize()
|
||||
try:
|
||||
# Update our game state with retry logic
|
||||
await self._retry_with_backoff(self.client.game.synchronize)
|
||||
except Exception as e:
|
||||
# This is a critical error. If we cannot synchronize the game, even with backoffs, we shouldn't continue.
|
||||
logger.critical(f"Failed to synchronize game state during phase update: {e}")
|
||||
|
||||
raise e
|
||||
|
||||
new_phase = self.client.game.get_current_phase()
|
||||
if new_phase != self.current_phase:
|
||||
|
|
@ -219,14 +311,20 @@ class SingleBotPlayer:
|
|||
|
||||
async def _handle_game_processed_async(self):
|
||||
"""Async handler for game processing."""
|
||||
# Synchronize to get the results
|
||||
await self.client.game.synchronize()
|
||||
try:
|
||||
# Synchronize to get the results with retry logic
|
||||
await self._retry_with_backoff(self.client.game.synchronize)
|
||||
|
||||
# Analyze the results
|
||||
await self._analyze_phase_results()
|
||||
# Analyze the results
|
||||
await self._analyze_phase_results()
|
||||
|
||||
self.orders_submitted = False
|
||||
self.waiting_for_orders = False
|
||||
self.orders_submitted = False
|
||||
self.waiting_for_orders = False
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to handle game processing: {e}")
|
||||
# Reset state even if synchronization failed
|
||||
self.orders_submitted = False
|
||||
self.waiting_for_orders = False
|
||||
|
||||
def _on_message_received(self, game, notification):
|
||||
"""Handle incoming diplomatic messages."""
|
||||
|
|
@ -289,7 +387,7 @@ class SingleBotPlayer:
|
|||
return
|
||||
if action_type == "build":
|
||||
# Get buildable locations for this power
|
||||
possible_orders = gather_possible_order(self.client.game, self.power_name)
|
||||
possible_orders = gather_possible_orders(self.client.game, self.power_name)
|
||||
if not possible_orders:
|
||||
logger.warning(f"No possible build orders for {self.power_name}")
|
||||
await self.client.set_orders(self.power_name, [])
|
||||
|
|
@ -309,6 +407,16 @@ class SingleBotPlayer:
|
|||
|
||||
elif action_type == "disband":
|
||||
# Get current units for disbanding
|
||||
power = self.client.game.get_power(self.power_name)
|
||||
current_units = list(power.units.keys())
|
||||
|
||||
# Create disband orders
|
||||
disband_orders = []
|
||||
for i, unit_location in enumerate(current_units[:count]):
|
||||
# Format: "A Berlin - DISBAND" or "F London - DISBAND"
|
||||
unit_type = power.units[unit_location][0] # 'A' or 'F'
|
||||
disband_order = f"{unit_type} {unit_location} - DISBAND"
|
||||
disband_orders.append(disband_order)
|
||||
|
||||
logger.info(f"Submitting disband orders for {self.power_name}: {disband_orders}")
|
||||
await self.client.set_orders(self.power_name, disband_orders)
|
||||
|
|
@ -323,60 +431,74 @@ class SingleBotPlayer:
|
|||
logger.debug("Orders already submitted for this phase")
|
||||
return
|
||||
|
||||
current_phase = self.client.game.get_current_phase()
|
||||
logger.info(f"Generating orders for {self.power_name} in phase {current_phase}...")
|
||||
try:
|
||||
current_phase = self.client.game.get_current_phase()
|
||||
logger.info(f"Generating orders for {self.power_name} in phase {current_phase}...")
|
||||
|
||||
# Get current board state
|
||||
board_state = self.client.game.get_state()
|
||||
# Get current board state
|
||||
board_state = self.client.game.get_state()
|
||||
|
||||
# Get possible orders
|
||||
possible_orders = gather_possible_orders(self.client.game, self.power_name)
|
||||
# Get possible orders
|
||||
possible_orders = gather_possible_orders(self.client.game, self.power_name)
|
||||
|
||||
logger.debug(f"Possible orders for {self.power_name}: {possible_orders}")
|
||||
logger.debug(f"Possible orders for {self.power_name}: {possible_orders}")
|
||||
|
||||
if not possible_orders:
|
||||
logger.info(f"No possible orders for {self.power_name}, submitting empty order set")
|
||||
await self.client.set_orders(self.power_name, [])
|
||||
self.orders_submitted = True
|
||||
return
|
||||
if not possible_orders:
|
||||
logger.info(f"No possible orders for {self.power_name}, submitting empty order set")
|
||||
await self._retry_with_backoff(self.client.set_orders, self.power_name, [])
|
||||
self.orders_submitted = True
|
||||
return
|
||||
|
||||
# Generate orders using AI
|
||||
orders = await get_valid_orders(
|
||||
game=self.client.game,
|
||||
client=self.agent.client,
|
||||
board_state=board_state,
|
||||
power_name=self.power_name,
|
||||
possible_orders=possible_orders,
|
||||
game_history=self.game_history,
|
||||
model_error_stats=self.error_stats,
|
||||
agent_goals=self.agent.goals,
|
||||
agent_relationships=self.agent.relationships,
|
||||
agent_private_diary_str=self.agent.format_private_diary_for_prompt(),
|
||||
phase=self.client.game.get_current_phase(),
|
||||
)
|
||||
|
||||
# Submit orders
|
||||
if orders:
|
||||
logger.info(f"Submitting orders: {orders}")
|
||||
await self.client.set_orders(self.power_name, orders)
|
||||
|
||||
# Generate order diary entry
|
||||
await self.agent.generate_order_diary_entry(
|
||||
self.client.game,
|
||||
orders,
|
||||
config.log_file_path,
|
||||
# Generate orders using AI
|
||||
orders = await get_valid_orders(
|
||||
game=self.client.game,
|
||||
client=self.agent.client,
|
||||
board_state=board_state,
|
||||
power_name=self.power_name,
|
||||
possible_orders=possible_orders,
|
||||
game_history=self.game_history,
|
||||
model_error_stats=self.error_stats,
|
||||
agent_goals=self.agent.goals,
|
||||
agent_relationships=self.agent.relationships,
|
||||
agent_private_diary_str=self.agent.format_private_diary_for_prompt(),
|
||||
phase=self.client.game.get_current_phase(),
|
||||
)
|
||||
else:
|
||||
logger.info("No valid orders generated, submitting empty order set")
|
||||
await self.client.set_orders(self.power_name, [])
|
||||
|
||||
self.orders_submitted = True
|
||||
self.waiting_for_orders = False
|
||||
logger.info("Orders submitted successfully")
|
||||
# Call the no wait so we don't sit around for the turns to end.
|
||||
# TODO: We probably don't want to call this here.
|
||||
# We want to call it when negotiations end,
|
||||
self.client.game.no_wait()
|
||||
# Submit orders with retry logic
|
||||
if orders:
|
||||
logger.info(f"Submitting orders: {orders}")
|
||||
await self._retry_with_backoff(self.client.set_orders, self.power_name, orders)
|
||||
|
||||
# Generate order diary entry (don't retry this if it fails)
|
||||
try:
|
||||
await self.agent.generate_order_diary_entry(
|
||||
self.client.game,
|
||||
orders,
|
||||
self.config.log_file_path,
|
||||
)
|
||||
except Exception as diary_error:
|
||||
logger.warning(f"Failed to generate order diary entry: {diary_error}")
|
||||
else:
|
||||
logger.info("No valid orders generated, submitting empty order set")
|
||||
await self._retry_with_backoff(self.client.set_orders, self.power_name, [])
|
||||
|
||||
self.orders_submitted = True
|
||||
self.waiting_for_orders = False
|
||||
logger.info("Orders submitted successfully")
|
||||
|
||||
# Call the no wait so we don't sit around for the turns to end.
|
||||
# TODO: We probably don't want to call this here.
|
||||
# We want to call it when negotiations end,
|
||||
try:
|
||||
self.client.game.no_wait()
|
||||
except Exception as no_wait_error:
|
||||
logger.warning(f"Failed to call no_wait: {no_wait_error}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to submit orders: {e}", exc_info=True)
|
||||
# Mark as submitted to avoid infinite retry loops
|
||||
self.orders_submitted = True
|
||||
self.waiting_for_orders = False
|
||||
|
||||
async def _analyze_phase_results(self):
|
||||
"""Analyze the results of the previous phase."""
|
||||
|
|
@ -394,7 +516,7 @@ class SingleBotPlayer:
|
|||
board_state=board_state,
|
||||
phase_summary=phase_summary,
|
||||
game_history=self.game_history,
|
||||
log_file_path=config.log_file_path,
|
||||
log_file_path=self.config.log_file_path,
|
||||
)
|
||||
|
||||
logger.info("Phase analysis complete")
|
||||
|
|
@ -422,7 +544,7 @@ class SingleBotPlayer:
|
|||
agent=self.agent,
|
||||
game_history=self.game_history,
|
||||
model_error_stats=self.error_stats,
|
||||
log_file_path=config.log_file_path,
|
||||
log_file_path=self.config.log_file_path,
|
||||
round_number=round_num,
|
||||
max_rounds=self.negotiation_rounds,
|
||||
)
|
||||
|
|
@ -505,7 +627,7 @@ class SingleBotPlayer:
|
|||
possible_orders=possible_orders,
|
||||
game_history=self.game_history,
|
||||
game_phase=self.client.get_current_short_phase(),
|
||||
log_file_path=config.log_file_path,
|
||||
log_file_path=self.config.log_file_path,
|
||||
active_powers=active_powers,
|
||||
agent_goals=self.agent.goals,
|
||||
agent_relationships=self.agent.relationships,
|
||||
|
|
@ -516,12 +638,17 @@ class SingleBotPlayer:
|
|||
if responses and len(responses) > 0:
|
||||
response_content = responses[0].get("content", "").strip()
|
||||
if response_content:
|
||||
await self.client.send_message(
|
||||
sender=self.power_name,
|
||||
recipient=message.sender,
|
||||
message=response_content,
|
||||
phase=self.client.get_current_short_phase(),
|
||||
)
|
||||
try:
|
||||
await self._retry_with_backoff(
|
||||
self.client.send_message,
|
||||
sender=self.power_name,
|
||||
recipient=message.sender,
|
||||
message=response_content,
|
||||
phase=self.client.get_current_short_phase(),
|
||||
)
|
||||
except Exception as send_error:
|
||||
logger.warning(f"Failed to send message response: {send_error}")
|
||||
return # Don't record the message if sending failed
|
||||
|
||||
# Add to game history
|
||||
self.game_history.add_message(
|
||||
|
|
@ -607,14 +734,31 @@ class SingleBotPlayer:
|
|||
|
||||
# Main event loop
|
||||
while self.running and not self.client.game.is_game_done:
|
||||
# Synchronize with server periodically
|
||||
await self.client.game.synchronize()
|
||||
try:
|
||||
# Synchronize with server periodically with retry logic
|
||||
await self._retry_with_backoff(self.client.game.synchronize)
|
||||
|
||||
# Check if we need to submit orders
|
||||
await self._check_if_orders_needed()
|
||||
# Check if we need to submit orders
|
||||
await self._check_if_orders_needed()
|
||||
|
||||
# Sleep for a bit before next iteration
|
||||
await asyncio.sleep(5)
|
||||
# Sleep for a bit before next iteration
|
||||
await asyncio.sleep(5)
|
||||
|
||||
except (asyncio.CancelledError, KeyboardInterrupt):
|
||||
logger.info("Bot operation cancelled, shutting down")
|
||||
break
|
||||
except (TimeoutError, asyncio.TimeoutError) as e:
|
||||
logger.warning(f"Timeout in main loop: {e}")
|
||||
# Continue loop but with a longer sleep
|
||||
await asyncio.sleep(10)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in main loop: {e}", exc_info=True)
|
||||
# Continue running unless it's a critical error
|
||||
if "circuit breaker" in str(e).lower():
|
||||
logger.error("Circuit breaker opened, waiting before retry")
|
||||
await asyncio.sleep(30)
|
||||
else:
|
||||
await asyncio.sleep(5)
|
||||
|
||||
if self.client.game.is_game_done:
|
||||
logger.info("Game has finished")
|
||||
|
|
@ -622,21 +766,94 @@ class SingleBotPlayer:
|
|||
logger.info("Bot shutting down")
|
||||
except GameIdException:
|
||||
logger.error(f"Game with id {self.game_id} does not exist on the server. Exiting...")
|
||||
except (asyncio.CancelledError, KeyboardInterrupt):
|
||||
logger.info("Bot cancelled or interrupted")
|
||||
except Exception as e:
|
||||
logger.error(f"Fatal error in bot: {e}", exc_info=True)
|
||||
finally:
|
||||
await self.cleanup()
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up resources."""
|
||||
"""Clean up resources with timeout protection."""
|
||||
logger.info("Starting cleanup process...")
|
||||
cleanup_timeout = 15.0 # Maximum time to spend on cleanup
|
||||
|
||||
try:
|
||||
if self.client.game:
|
||||
# We need to leave a game here, otherwise the connection stays for some reason and we can never rejoin
|
||||
self.client.game.leave()
|
||||
if self.client:
|
||||
await self.client.close()
|
||||
logger.info("Cleanup complete")
|
||||
# Use asyncio.wait_for to prevent hanging during cleanup
|
||||
await asyncio.wait_for(self._perform_cleanup(), timeout=cleanup_timeout)
|
||||
logger.info("Cleanup completed successfully")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Cleanup timed out after {cleanup_timeout} seconds")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during cleanup: {e}")
|
||||
|
||||
async def _perform_cleanup(self):
|
||||
"""Perform the actual cleanup operations."""
|
||||
cleanup_tasks = []
|
||||
|
||||
# Game cleanup
|
||||
if hasattr(self, "client") and self.client and hasattr(self.client, "game") and self.client.game:
|
||||
logger.debug("Cleaning up game connection...")
|
||||
try:
|
||||
# Use asyncio.create_task to make game.leave() non-blocking
|
||||
leave_task = asyncio.create_task(self._safe_game_leave())
|
||||
cleanup_tasks.append(leave_task)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error creating game leave task: {e}")
|
||||
|
||||
# Client cleanup
|
||||
if hasattr(self, "client") and self.client:
|
||||
logger.debug("Cleaning up client connection...")
|
||||
try:
|
||||
close_task = asyncio.create_task(self._safe_client_close())
|
||||
cleanup_tasks.append(close_task)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error creating client close task: {e}")
|
||||
|
||||
# Wait for all cleanup tasks with individual timeouts
|
||||
if cleanup_tasks:
|
||||
done, pending = await asyncio.wait(
|
||||
cleanup_tasks,
|
||||
timeout=10.0, # 10 second timeout for all cleanup tasks
|
||||
return_when=asyncio.ALL_COMPLETED,
|
||||
)
|
||||
|
||||
# Cancel any pending tasks
|
||||
for task in pending:
|
||||
logger.warning(f"Cancelling pending cleanup task: {task}")
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"Error cancelling cleanup task: {e}")
|
||||
|
||||
async def _safe_game_leave(self):
|
||||
"""Safely leave the game with timeout."""
|
||||
try:
|
||||
# Some diplomacy client implementations have async leave, others are sync
|
||||
if asyncio.iscoroutinefunction(self.client.game.leave):
|
||||
await asyncio.wait_for(self.client.game.leave(), timeout=5.0)
|
||||
else:
|
||||
# Run synchronous leave in a thread to avoid blocking
|
||||
await asyncio.get_event_loop().run_in_executor(None, self.client.game.leave)
|
||||
logger.debug("Successfully left game")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Game leave operation timed out")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error leaving game: {e}")
|
||||
|
||||
async def _safe_client_close(self):
|
||||
"""Safely close the client with timeout."""
|
||||
try:
|
||||
await asyncio.wait_for(self.client.close(), timeout=5.0)
|
||||
logger.debug("Successfully closed client")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Client close operation timed out")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing client: {e}")
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
"""Parse command line arguments."""
|
||||
|
|
@ -656,28 +873,63 @@ def parse_arguments():
|
|||
default=3,
|
||||
help="Number of negotiation rounds per movement phase (default: 3)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--connection-timeout",
|
||||
type=float,
|
||||
default=30.0,
|
||||
help="Timeout for network operations in seconds (default: 30.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-retries",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Maximum number of retries for failed operations (default: 3)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--retry-delay",
|
||||
type=float,
|
||||
default=2.0,
|
||||
help="Base delay between retries in seconds (default: 2.0)",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main entry point."""
|
||||
args = parse_arguments()
|
||||
if not args.username:
|
||||
args.username = f"bot_{args.power}"
|
||||
"""Main entry point with comprehensive error handling."""
|
||||
bot = None
|
||||
try:
|
||||
args = parse_arguments()
|
||||
if not args.username:
|
||||
args.username = f"bot_{args.power}"
|
||||
|
||||
bot = SingleBotPlayer(
|
||||
hostname=args.hostname,
|
||||
port=args.port,
|
||||
username=args.username,
|
||||
password=args.password,
|
||||
power_name=args.power,
|
||||
model_name=args.model,
|
||||
game_id=args.game_id,
|
||||
negotiation_rounds=args.negotiation_rounds,
|
||||
)
|
||||
config = Configuration(args.power)
|
||||
await bot.run()
|
||||
bot = SingleBotPlayer(
|
||||
hostname=args.hostname,
|
||||
port=args.port,
|
||||
username=args.username,
|
||||
password=args.password,
|
||||
power_name=args.power,
|
||||
model_name=args.model,
|
||||
game_id=args.game_id,
|
||||
negotiation_rounds=args.negotiation_rounds,
|
||||
connection_timeout=args.connection_timeout,
|
||||
max_retries=args.max_retries,
|
||||
retry_delay=args.retry_delay,
|
||||
)
|
||||
|
||||
await bot.run()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Received keyboard interrupt")
|
||||
except Exception as e:
|
||||
logger.error(f"Fatal error in main: {e}", exc_info=True)
|
||||
finally:
|
||||
if bot:
|
||||
# Ensure cleanup happens even if there was an error
|
||||
try:
|
||||
await bot.cleanup()
|
||||
except Exception as cleanup_error:
|
||||
logger.error(f"Error during final cleanup: {cleanup_error}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
# Test package initialization
|
||||
# Test package initialization
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ import pytest_asyncio
|
|||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
from .test_fake_websocket_server import FakeServerManager, FakeWebSocketServer
|
||||
from typed_websocket_client import TypedWebSocketDiplomacyClient
|
||||
from websocket_diplomacy_client import WebSocketDiplomacyClient
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function", autouse=True)
|
||||
|
|
@ -39,13 +39,13 @@ def credentials() -> Generator[Dict[str, str], None]:
|
|||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client() -> AsyncGenerator[TypedWebSocketDiplomacyClient, None]:
|
||||
async def client() -> AsyncGenerator[WebSocketDiplomacyClient, None]:
|
||||
"""
|
||||
Fixture that provides a TypedWebSocketDiplomacyClient instance.
|
||||
Fixture that provides a WebSocketDiplomacyClient instance.
|
||||
|
||||
The client is configured to connect to the fake server on port 8433.
|
||||
"""
|
||||
client = TypedWebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
|
||||
client = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
|
||||
yield client
|
||||
|
||||
# Cleanup: close the client connection
|
||||
|
|
@ -58,13 +58,13 @@ async def client() -> AsyncGenerator[TypedWebSocketDiplomacyClient, None]:
|
|||
@pytest.fixture
|
||||
async def authenticated_client(fake_server, client):
|
||||
"""
|
||||
Fixture that provides an authenticated TypedWebSocketDiplomacyClient.
|
||||
Fixture that provides an authenticated WebSocketDiplomacyClient.
|
||||
|
||||
This client is already connected and authenticated, ready for testing
|
||||
game operations.
|
||||
"""
|
||||
await client.connect()
|
||||
token = await client.authenticate("test_user", "test_password")
|
||||
await client.connect_and_authenticate("test_user", "test_password")
|
||||
token = client.token
|
||||
assert token is not None
|
||||
assert client.token == token
|
||||
yield client
|
||||
|
|
|
|||
|
|
@ -38,49 +38,41 @@ from models import (
|
|||
class FakeWebSocketServer:
|
||||
"""
|
||||
A fake WebSocket server that implements the Diplomacy protocol for testing.
|
||||
|
||||
|
||||
This server maintains minimal state and responds to requests with valid
|
||||
responses according to the WEBSOCKET.md protocol specification.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, host: str = "localhost", port: int = 8433):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.server = None
|
||||
self.connected_clients: Set[websockets.WebSocketServerProtocol] = set()
|
||||
|
||||
|
||||
# Mock server state
|
||||
self.users = {
|
||||
"test_user": "test_password",
|
||||
"ai_player": "password",
|
||||
"player1": "password"
|
||||
}
|
||||
self.users = {"test_user": "test_password", "ai_player": "password", "player1": "password"}
|
||||
self.tokens: Dict[str, str] = {} # token -> username
|
||||
self.games: Dict[str, Dict[str, Any]] = {}
|
||||
self.game_counter = 1
|
||||
|
||||
|
||||
async def start(self):
|
||||
"""Start the fake WebSocket server."""
|
||||
logger.info(f"Starting fake WebSocket server on {self.host}:{self.port}")
|
||||
self.server = await websockets.serve(
|
||||
self.handle_client,
|
||||
self.host,
|
||||
self.port
|
||||
)
|
||||
self.server = await websockets.serve(self.handle_client, self.host, self.port)
|
||||
logger.info("Fake WebSocket server started")
|
||||
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the fake WebSocket server."""
|
||||
if self.server:
|
||||
self.server.close()
|
||||
await self.server.wait_closed()
|
||||
logger.info("Fake WebSocket server stopped")
|
||||
|
||||
|
||||
async def handle_client(self, websocket):
|
||||
"""Handle a new WebSocket client connection."""
|
||||
self.connected_clients.add(websocket)
|
||||
logger.info(f"Client connected from {websocket.remote_address}")
|
||||
|
||||
|
||||
try:
|
||||
async for message in websocket:
|
||||
await self.handle_message(websocket, message)
|
||||
|
|
@ -90,49 +82,45 @@ class FakeWebSocketServer:
|
|||
logger.error(f"Error handling client: {e}")
|
||||
finally:
|
||||
self.connected_clients.discard(websocket)
|
||||
|
||||
|
||||
async def handle_message(self, websocket, message_text: str):
|
||||
"""Handle an incoming message from a client."""
|
||||
try:
|
||||
# Parse the raw message
|
||||
data = json.loads(message_text)
|
||||
request = parse_message(data)
|
||||
|
||||
|
||||
# Generate response based on request type
|
||||
response = await self.generate_response(request)
|
||||
|
||||
|
||||
# Send response back to client
|
||||
if response:
|
||||
response_data = serialize_message(response)
|
||||
await websocket.send(json.dumps(response_data))
|
||||
logger.debug(f"Sent response: {response.name}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message: {e}")
|
||||
logger.debug(f"Raw message: {message_text}")
|
||||
|
||||
|
||||
# Send error response if we can extract request_id
|
||||
try:
|
||||
data = json.loads(message_text)
|
||||
request_id = data.get("request_id", str(uuid.uuid4()))
|
||||
error_response = ErrorResponse(
|
||||
request_id=request_id,
|
||||
error_type="PARSING_ERROR",
|
||||
message=str(e)
|
||||
)
|
||||
error_response = ErrorResponse(request_id=request_id, error_type="PARSING_ERROR", message=str(e))
|
||||
error_data = serialize_message(error_response)
|
||||
await websocket.send(json.dumps(error_data))
|
||||
except:
|
||||
logger.error("Could not send error response")
|
||||
|
||||
|
||||
async def generate_response(self, request) -> Optional[Any]:
|
||||
"""Generate appropriate response for a request."""
|
||||
logger.info(f"Processing request: {request.name}")
|
||||
|
||||
|
||||
# Authentication requests
|
||||
if isinstance(request, SignInRequest):
|
||||
return await self.handle_sign_in(request)
|
||||
|
||||
|
||||
# Channel-level requests (require token)
|
||||
elif isinstance(request, CreateGameRequest):
|
||||
return await self.handle_create_game(request)
|
||||
|
|
@ -140,7 +128,7 @@ class FakeWebSocketServer:
|
|||
return await self.handle_join_game(request)
|
||||
elif isinstance(request, ListGamesRequest):
|
||||
return await self.handle_list_games(request)
|
||||
|
||||
|
||||
# Game-level requests (require token + game context)
|
||||
elif isinstance(request, SetOrdersRequest):
|
||||
return await self.handle_set_orders(request)
|
||||
|
|
@ -148,45 +136,36 @@ class FakeWebSocketServer:
|
|||
return await self.handle_process_game(request)
|
||||
elif isinstance(request, GetAllPossibleOrdersRequest):
|
||||
return await self.handle_get_possible_orders(request)
|
||||
|
||||
|
||||
else:
|
||||
logger.warning(f"Unhandled request type: {type(request)}")
|
||||
return ErrorResponse(
|
||||
request_id=request.request_id,
|
||||
error_type="UNSUPPORTED_REQUEST",
|
||||
message=f"Request type {request.name} not supported by fake server"
|
||||
request_id=request.request_id, error_type="UNSUPPORTED_REQUEST", message=f"Request type {request.name} not supported by fake server"
|
||||
)
|
||||
|
||||
|
||||
async def handle_sign_in(self, request: SignInRequest) -> Any:
|
||||
"""Handle authentication request."""
|
||||
username = request.username
|
||||
password = request.password
|
||||
|
||||
|
||||
if username in self.users and self.users[username] == password:
|
||||
# Generate auth token
|
||||
token = f"fake_token_{uuid.uuid4().hex[:16]}"
|
||||
self.tokens[token] = username
|
||||
|
||||
return DataTokenResponse(
|
||||
request_id=request.request_id,
|
||||
data=token
|
||||
)
|
||||
|
||||
return DataTokenResponse(request_id=request.request_id, data=token)
|
||||
else:
|
||||
return ErrorResponse(
|
||||
request_id=request.request_id,
|
||||
error_type="AUTHENTICATION_ERROR",
|
||||
message="Invalid username or password"
|
||||
)
|
||||
|
||||
return ErrorResponse(request_id=request.request_id, error_type="AUTHENTICATION_ERROR", message="Invalid username or password")
|
||||
|
||||
async def handle_create_game(self, request: CreateGameRequest) -> Any:
|
||||
"""Handle game creation request."""
|
||||
if not self.validate_token(request.token):
|
||||
return self.create_auth_error(request.request_id)
|
||||
|
||||
|
||||
# Create a new game
|
||||
game_id = f"GAME_{self.game_counter:04d}"
|
||||
self.game_counter += 1
|
||||
|
||||
|
||||
game_data = {
|
||||
"game_id": game_id,
|
||||
"map_name": request.map_name,
|
||||
|
|
@ -203,102 +182,83 @@ class FakeWebSocketServer:
|
|||
"RUSSIA": {"units": [], "centers": ["MOS", "SEV", "STP", "WAR"], "is_eliminated": False},
|
||||
"TURKEY": {"units": [], "centers": ["ANK", "CON", "SMY"], "is_eliminated": False},
|
||||
},
|
||||
"controlled_powers": {}
|
||||
"controlled_powers": {},
|
||||
}
|
||||
|
||||
|
||||
# If a specific power was requested, assign it
|
||||
if request.power_name:
|
||||
game_data["controlled_powers"][request.power_name] = self.tokens[request.token]
|
||||
|
||||
|
||||
self.games[game_id] = game_data
|
||||
|
||||
return DataGameResponse(
|
||||
request_id=request.request_id,
|
||||
data=game_data
|
||||
)
|
||||
|
||||
|
||||
return DataGameResponse(request_id=request.request_id, data=game_data)
|
||||
|
||||
async def handle_join_game(self, request: JoinGameRequest) -> Any:
|
||||
"""Handle game join request."""
|
||||
if not self.validate_token(request.token):
|
||||
return self.create_auth_error(request.request_id)
|
||||
|
||||
|
||||
game_id = request.game_id
|
||||
if game_id not in self.games:
|
||||
return ErrorResponse(
|
||||
request_id=request.request_id,
|
||||
error_type="GAME_NOT_FOUND",
|
||||
message=f"Game {game_id} not found"
|
||||
)
|
||||
|
||||
return ErrorResponse(request_id=request.request_id, error_type="GAME_NOT_FOUND", message=f"Game {game_id} not found")
|
||||
|
||||
game_data = self.games[game_id].copy()
|
||||
|
||||
|
||||
# If a specific power was requested, assign it
|
||||
if request.power_name:
|
||||
game_data["controlled_powers"][request.power_name] = self.tokens[request.token]
|
||||
self.games[game_id] = game_data
|
||||
|
||||
return DataGameResponse(
|
||||
request_id=request.request_id,
|
||||
data=game_data
|
||||
)
|
||||
|
||||
|
||||
return DataGameResponse(request_id=request.request_id, data=game_data)
|
||||
|
||||
async def handle_list_games(self, request: ListGamesRequest) -> Any:
|
||||
"""Handle list games request."""
|
||||
if not self.validate_token(request.token):
|
||||
return self.create_auth_error(request.request_id)
|
||||
|
||||
|
||||
# Return simplified game info
|
||||
games_list = []
|
||||
for game_id, game_data in self.games.items():
|
||||
games_list.append({
|
||||
"game_id": game_id,
|
||||
"map_name": game_data["map_name"],
|
||||
"status": game_data["status"],
|
||||
"phase": game_data["phase"],
|
||||
"n_controls": game_data["n_controls"]
|
||||
})
|
||||
|
||||
return DataGamesResponse(
|
||||
request_id=request.request_id,
|
||||
data=games_list
|
||||
)
|
||||
|
||||
games_list.append(
|
||||
{
|
||||
"game_id": game_id,
|
||||
"map_name": game_data["map_name"],
|
||||
"status": game_data["status"],
|
||||
"phase": game_data["phase"],
|
||||
"n_controls": game_data["n_controls"],
|
||||
}
|
||||
)
|
||||
|
||||
return DataGamesResponse(request_id=request.request_id, data=games_list)
|
||||
|
||||
async def handle_set_orders(self, request: SetOrdersRequest) -> Any:
|
||||
"""Handle set orders request."""
|
||||
if not self.validate_token(request.token):
|
||||
return self.create_auth_error(request.request_id)
|
||||
|
||||
|
||||
game_id = request.game_id
|
||||
if game_id not in self.games:
|
||||
return ErrorResponse(
|
||||
request_id=request.request_id,
|
||||
error_type="GAME_NOT_FOUND",
|
||||
message=f"Game {game_id} not found"
|
||||
)
|
||||
|
||||
return ErrorResponse(request_id=request.request_id, error_type="GAME_NOT_FOUND", message=f"Game {game_id} not found")
|
||||
|
||||
# In a real server, we'd validate the orders and store them
|
||||
# For testing, we just acknowledge receipt
|
||||
logger.info(f"Orders received for {request.game_role}: {request.orders}")
|
||||
|
||||
|
||||
return OkResponse(request_id=request.request_id)
|
||||
|
||||
|
||||
async def handle_process_game(self, request: ProcessGameRequest) -> Any:
|
||||
"""Handle process game request."""
|
||||
if not self.validate_token(request.token):
|
||||
return self.create_auth_error(request.request_id)
|
||||
|
||||
|
||||
game_id = request.game_id
|
||||
if game_id not in self.games:
|
||||
return ErrorResponse(
|
||||
request_id=request.request_id,
|
||||
error_type="GAME_NOT_FOUND",
|
||||
message=f"Game {game_id} not found"
|
||||
)
|
||||
|
||||
return ErrorResponse(request_id=request.request_id, error_type="GAME_NOT_FOUND", message=f"Game {game_id} not found")
|
||||
|
||||
# Simulate game processing
|
||||
game_data = self.games[game_id]
|
||||
current_phase = game_data["phase"]
|
||||
|
||||
|
||||
# Simple phase progression
|
||||
if current_phase == "S1901M":
|
||||
game_data["phase"] = "F1901M"
|
||||
|
|
@ -307,55 +267,48 @@ class FakeWebSocketServer:
|
|||
else:
|
||||
# For testing, cycle back to start
|
||||
game_data["phase"] = "S1902M"
|
||||
|
||||
|
||||
self.games[game_id] = game_data
|
||||
|
||||
|
||||
# Send notification to all clients (in real implementation)
|
||||
# For testing, we'll just return OK
|
||||
return OkResponse(request_id=request.request_id)
|
||||
|
||||
|
||||
async def handle_get_possible_orders(self, request: GetAllPossibleOrdersRequest) -> Any:
|
||||
"""Handle get possible orders request."""
|
||||
if not self.validate_token(request.token):
|
||||
return self.create_auth_error(request.request_id)
|
||||
|
||||
|
||||
# Return mock possible orders
|
||||
possible_orders = {
|
||||
"PAR": ["A PAR H", "A PAR - BUR", "A PAR - PIC", "A PAR - GAS"],
|
||||
"BRE": ["F BRE H", "F BRE - MAO", "F BRE - ENG", "F BRE - PIC"],
|
||||
"MAR": ["A MAR H", "A MAR - GAS", "A MAR - SPA", "A MAR - PIE"],
|
||||
}
|
||||
|
||||
return DataPossibleOrdersResponse(
|
||||
request_id=request.request_id,
|
||||
data=possible_orders
|
||||
)
|
||||
|
||||
|
||||
return DataPossibleOrdersResponse(request_id=request.request_id, data=possible_orders)
|
||||
|
||||
def validate_token(self, token: str) -> bool:
|
||||
"""Validate an authentication token."""
|
||||
return token in self.tokens
|
||||
|
||||
|
||||
def create_auth_error(self, request_id: str) -> ErrorResponse:
|
||||
"""Create a standard authentication error response."""
|
||||
return ErrorResponse(
|
||||
request_id=request_id,
|
||||
error_type="AUTHENTICATION_ERROR",
|
||||
message="Invalid or missing authentication token"
|
||||
)
|
||||
return ErrorResponse(request_id=request_id, error_type="AUTHENTICATION_ERROR", message="Invalid or missing authentication token")
|
||||
|
||||
|
||||
class FakeServerManager:
|
||||
"""Context manager for the fake WebSocket server."""
|
||||
|
||||
|
||||
def __init__(self, host: str = "localhost", port: int = 8433):
|
||||
self.server = FakeWebSocketServer(host, port)
|
||||
|
||||
|
||||
async def __aenter__(self):
|
||||
await self.server.start()
|
||||
# Give the server a moment to start
|
||||
await asyncio.sleep(0.1)
|
||||
return self.server
|
||||
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.server.stop()
|
||||
|
||||
|
|
@ -375,4 +328,4 @@ async def run_fake_server(host: str = "localhost", port: int = 8433):
|
|||
|
||||
if __name__ == "__main__":
|
||||
# Run the fake server for manual testing
|
||||
asyncio.run(run_fake_server())
|
||||
asyncio.run(run_fake_server())
|
||||
|
|
|
|||
347
bot_client/tests/test_inter_power_messaging.py
Normal file
347
bot_client/tests/test_inter_power_messaging.py
Normal file
|
|
@ -0,0 +1,347 @@
|
|||
"""
|
||||
Integration tests for inter-power messaging functionality.
|
||||
|
||||
These tests verify that the enhanced WebSocket client can properly:
|
||||
1. Send messages between powers
|
||||
2. Retrieve and filter messages
|
||||
3. Handle broadcast messaging
|
||||
4. Track message patterns for strategic communication
|
||||
|
||||
The tests use the fake server infrastructure but test real message flow
|
||||
without mocking the transport layer.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
from typing import Dict, List
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from websocket_diplomacy_client import WebSocketDiplomacyClient, connect_to_diplomacy_server
|
||||
from diplomacy.engine.message import Message, GLOBAL
|
||||
|
||||
|
||||
class TestBasicMessaging:
|
||||
"""Test core messaging functionality between powers."""
|
||||
|
||||
@pytest.fixture
|
||||
async def two_clients(self, fake_server):
|
||||
"""Fixture providing two authenticated clients in the same game."""
|
||||
# Create two clients
|
||||
client_france = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
|
||||
client_germany = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
|
||||
|
||||
try:
|
||||
# Connect and authenticate both clients
|
||||
await client_france.connect_and_authenticate("test_user", "test_password")
|
||||
await client_germany.connect_and_authenticate("ai_player", "password")
|
||||
|
||||
# Create a game with first client as FRANCE
|
||||
await client_france.create_game(
|
||||
map_name="standard",
|
||||
rules=["IGNORE_ERRORS", "POWER_CHOICE"],
|
||||
power_name="FRANCE",
|
||||
n_controls=2, # Only need 2 powers for testing
|
||||
)
|
||||
|
||||
# Second client joins as GERMANY
|
||||
await client_germany.join_game(game_id=client_france.game_id, power_name="GERMANY")
|
||||
|
||||
# Synchronize both clients
|
||||
await client_france.synchronize()
|
||||
await client_germany.synchronize()
|
||||
|
||||
yield {"france": client_france, "germany": client_germany}
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
try:
|
||||
await client_france.close()
|
||||
await client_germany.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
async def test_direct_message_sending(self, two_clients):
|
||||
"""Test sending a direct message between two powers."""
|
||||
france_client = two_clients["france"]
|
||||
germany_client = two_clients["germany"]
|
||||
|
||||
# Send message from FRANCE to GERMANY
|
||||
test_message = "Hello Germany, shall we form an alliance?"
|
||||
await france_client.send_message(sender="FRANCE", recipient="GERMANY", message=test_message)
|
||||
|
||||
# Allow message to propagate
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Synchronize both clients to get latest messages
|
||||
await france_client.synchronize()
|
||||
await germany_client.synchronize()
|
||||
|
||||
# Check that Germany received the message
|
||||
germany_messages = await germany_client.get_recent_messages(limit=10)
|
||||
|
||||
# Find our message
|
||||
sent_message = None
|
||||
for msg in germany_messages:
|
||||
if msg.sender == "FRANCE" and msg.recipient == "GERMANY":
|
||||
sent_message = msg
|
||||
break
|
||||
|
||||
assert sent_message is not None, "Message was not received by Germany"
|
||||
assert sent_message.message == test_message
|
||||
assert sent_message.sender == "FRANCE"
|
||||
assert sent_message.recipient == "GERMANY"
|
||||
|
||||
async def test_global_message_broadcasting(self, two_clients):
|
||||
"""Test broadcasting a global message."""
|
||||
france_client = two_clients["france"]
|
||||
germany_client = two_clients["germany"]
|
||||
|
||||
# Send global message from FRANCE
|
||||
test_message = "Greetings to all powers! Let's have a good game."
|
||||
await france_client.send_message(sender="FRANCE", recipient=GLOBAL, message=test_message)
|
||||
|
||||
# Allow message to propagate
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Synchronize both clients
|
||||
await france_client.synchronize()
|
||||
await germany_client.synchronize()
|
||||
|
||||
# Check that both clients can see the global message
|
||||
france_messages = await france_client.get_recent_messages(limit=10)
|
||||
germany_messages = await germany_client.get_recent_messages(limit=10)
|
||||
|
||||
# Find the global message in both clients
|
||||
def find_global_message(messages):
|
||||
for msg in messages:
|
||||
if msg.sender == "FRANCE" and msg.recipient == GLOBAL:
|
||||
return msg
|
||||
return None
|
||||
|
||||
france_msg = find_global_message(france_messages)
|
||||
germany_msg = find_global_message(germany_messages)
|
||||
|
||||
assert france_msg is not None, "France should see its own global message"
|
||||
assert germany_msg is not None, "Germany should see France's global message"
|
||||
assert france_msg.message == test_message
|
||||
assert germany_msg.message == test_message
|
||||
|
||||
async def test_broadcast_to_all_active_powers(self, two_clients):
|
||||
"""Test the enhanced broadcast functionality."""
|
||||
france_client = two_clients["france"]
|
||||
germany_client = two_clients["germany"]
|
||||
|
||||
# Send broadcast message to all active powers
|
||||
test_message = "This is a broadcast to all active powers."
|
||||
await france_client.send_broadcast_message(sender="FRANCE", message=test_message)
|
||||
|
||||
# Allow messages to propagate
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Synchronize clients
|
||||
await france_client.synchronize()
|
||||
await germany_client.synchronize()
|
||||
|
||||
# Check that Germany received the direct message from broadcast
|
||||
germany_messages = await germany_client.get_recent_messages(limit=10)
|
||||
|
||||
broadcast_message = None
|
||||
for msg in germany_messages:
|
||||
if msg.sender == "FRANCE" and msg.recipient == "GERMANY" and msg.message == test_message:
|
||||
broadcast_message = msg
|
||||
break
|
||||
|
||||
assert broadcast_message is not None, "Germany should receive broadcast message"
|
||||
assert broadcast_message.message == test_message
|
||||
|
||||
async def test_message_filtering_by_phase(self, two_clients):
|
||||
"""Test message retrieval filtered by game phase."""
|
||||
france_client = two_clients["france"]
|
||||
|
||||
# Get current phase
|
||||
current_phase = france_client.get_current_short_phase()
|
||||
|
||||
# Send a message
|
||||
test_message = "Phase-specific message"
|
||||
await france_client.send_message(sender="FRANCE", recipient="GERMANY", message=test_message)
|
||||
|
||||
# Allow message to propagate
|
||||
await asyncio.sleep(0.1)
|
||||
await france_client.synchronize()
|
||||
|
||||
# Get messages for current phase
|
||||
phase_messages = await france_client.get_recent_messages(phase=current_phase, limit=10)
|
||||
|
||||
# Find our message
|
||||
found_message = None
|
||||
for msg in phase_messages:
|
||||
if msg.message == test_message:
|
||||
found_message = msg
|
||||
break
|
||||
|
||||
assert found_message is not None, "Message should be found in current phase"
|
||||
assert found_message.phase == current_phase
|
||||
|
||||
|
||||
class TestMessageHistory:
|
||||
"""Test message history and retrieval functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
async def client_with_messages(self, fake_server):
|
||||
"""Fixture providing a client with some test messages."""
|
||||
client = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
|
||||
|
||||
try:
|
||||
await client.connect_and_authenticate("test_user", "test_password")
|
||||
await client.create_game(
|
||||
map_name="standard",
|
||||
rules=["IGNORE_ERRORS", "POWER_CHOICE"],
|
||||
power_name="FRANCE",
|
||||
n_controls=1,
|
||||
)
|
||||
await client.synchronize()
|
||||
|
||||
# Send several test messages
|
||||
test_messages = [
|
||||
("Hello world!", GLOBAL),
|
||||
("Private message to England", "ENGLAND"),
|
||||
("Another global message", GLOBAL),
|
||||
("Direct to Germany", "GERMANY"),
|
||||
]
|
||||
|
||||
for message, recipient in test_messages:
|
||||
await client.send_message(sender="FRANCE", recipient=recipient, message=message)
|
||||
await asyncio.sleep(0.05) # Small delay between messages
|
||||
|
||||
await client.synchronize()
|
||||
yield client
|
||||
|
||||
finally:
|
||||
try:
|
||||
await client.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
async def test_get_recent_messages_limit(self, client_with_messages):
|
||||
"""Test that message limit parameter works correctly."""
|
||||
client = client_with_messages
|
||||
|
||||
# Get recent messages with different limits
|
||||
messages_3 = await client.get_recent_messages(limit=3)
|
||||
messages_2 = await client.get_recent_messages(limit=2)
|
||||
messages_1 = await client.get_recent_messages(limit=1)
|
||||
|
||||
assert len(messages_3) <= 3
|
||||
assert len(messages_2) <= 2
|
||||
assert len(messages_1) <= 1
|
||||
|
||||
# Messages should be in reverse chronological order (most recent first)
|
||||
if len(messages_3) > 1:
|
||||
# Check that timestamps are in descending order
|
||||
for i in range(len(messages_3) - 1):
|
||||
msg1_time = messages_3[i].time_sent or 0
|
||||
msg2_time = messages_3[i + 1].time_sent or 0
|
||||
assert msg1_time >= msg2_time, "Messages should be in reverse chronological order"
|
||||
|
||||
async def test_message_retrieval_by_sender(self, client_with_messages):
|
||||
"""Test filtering messages by sender."""
|
||||
client = client_with_messages
|
||||
|
||||
# Get all recent messages
|
||||
all_messages = await client.get_recent_messages(limit=20)
|
||||
|
||||
# Filter messages from FRANCE
|
||||
france_messages = [msg for msg in all_messages if msg.sender == "FRANCE"]
|
||||
|
||||
# We should have the 4 messages we sent as FRANCE
|
||||
assert len(france_messages) >= 4, f"Expected at least 4 FRANCE messages, got {len(france_messages)}"
|
||||
|
||||
# All messages should be from FRANCE
|
||||
for msg in france_messages:
|
||||
assert msg.sender == "FRANCE"
|
||||
|
||||
async def test_empty_message_history(self, fake_server):
|
||||
"""Test behavior when no messages exist."""
|
||||
client = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
|
||||
|
||||
try:
|
||||
await client.connect_and_authenticate("ai_player", "password")
|
||||
await client.create_game(
|
||||
map_name="standard",
|
||||
rules=["IGNORE_ERRORS"],
|
||||
power_name="AUSTRIA",
|
||||
n_controls=1,
|
||||
)
|
||||
await client.synchronize()
|
||||
|
||||
# Get messages when none exist
|
||||
messages = await client.get_recent_messages(limit=10)
|
||||
|
||||
# Should return empty list, not None or error
|
||||
assert isinstance(messages, list)
|
||||
assert len(messages) == 0
|
||||
|
||||
finally:
|
||||
try:
|
||||
await client.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
class TestMessageErrorHandling:
|
||||
"""Test error handling in messaging functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
async def client(self, fake_server):
|
||||
"""Basic authenticated client for error testing."""
|
||||
client = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
|
||||
|
||||
try:
|
||||
await client.connect_and_authenticate("player1", "password")
|
||||
await client.create_game(
|
||||
map_name="standard",
|
||||
rules=["IGNORE_ERRORS"],
|
||||
power_name="ITALY",
|
||||
n_controls=1,
|
||||
)
|
||||
await client.synchronize()
|
||||
yield client
|
||||
|
||||
finally:
|
||||
try:
|
||||
await client.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
async def test_send_message_to_invalid_recipient(self, client):
|
||||
"""Test sending message to non-existent power."""
|
||||
# This should not raise an exception - the server/game should handle invalid recipients
|
||||
try:
|
||||
await client.send_message(sender="ITALY", recipient="INVALID_POWER", message="This should not crash")
|
||||
# If we get here, the call succeeded (which is fine)
|
||||
except Exception as e:
|
||||
# If an exception is raised, it should be a specific diplomacy exception, not a crash
|
||||
assert "INVALID_POWER" in str(e) or "recipient" in str(e).lower()
|
||||
|
||||
async def test_send_empty_message(self, client):
|
||||
"""Test sending empty message."""
|
||||
# Empty messages should be handled gracefully
|
||||
await client.send_message(sender="ITALY", recipient="FRANCE", message="")
|
||||
# If we get here without exception, the empty message was handled properly
|
||||
|
||||
async def test_get_messages_before_game_setup(self, fake_server):
|
||||
"""Test getting messages when no game is joined."""
|
||||
client = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
|
||||
|
||||
try:
|
||||
await client.connect_and_authenticate("test_user", "test_password")
|
||||
|
||||
# Try to get messages without joining a game
|
||||
with pytest.raises(Exception): # Should raise some form of exception
|
||||
await client.get_recent_messages()
|
||||
|
||||
finally:
|
||||
try:
|
||||
await client.close()
|
||||
except:
|
||||
pass
|
||||
356
bot_client/tests/test_message_functionality_unit.py
Normal file
356
bot_client/tests/test_message_functionality_unit.py
Normal file
|
|
@ -0,0 +1,356 @@
|
|||
"""
|
||||
Unit tests for the new inter-power messaging functionality.
|
||||
|
||||
These tests focus on testing the specific functionality we added:
|
||||
1. Enhanced WebSocketDiplomacyClient methods
|
||||
2. Message tracking and statistics in SingleBotPlayer
|
||||
3. Negotiation targeting logic
|
||||
|
||||
These are more focused unit tests that don't require full server integration.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
from typing import Dict, List
|
||||
|
||||
from websocket_diplomacy_client import WebSocketDiplomacyClient
|
||||
from single_bot_player import SingleBotPlayer
|
||||
from websocket_negotiations import (
|
||||
analyze_recent_messages_for_targeting,
|
||||
should_participate_in_negotiations,
|
||||
get_negotiation_delay,
|
||||
)
|
||||
from diplomacy.engine.message import Message, GLOBAL
|
||||
|
||||
|
||||
class TestWebSocketClientEnhancements:
|
||||
"""Test the enhanced methods we added to WebSocketDiplomacyClient."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_client(self):
|
||||
"""Create a mocked WebSocketDiplomacyClient for testing."""
|
||||
client = WebSocketDiplomacyClient()
|
||||
|
||||
# Mock the game and related objects
|
||||
client.game = MagicMock()
|
||||
mock_powers = {
|
||||
"FRANCE": MagicMock(is_eliminated=lambda: False),
|
||||
"ENGLAND": MagicMock(is_eliminated=lambda: False),
|
||||
"GERMANY": MagicMock(is_eliminated=lambda: False),
|
||||
"ITALY": MagicMock(is_eliminated=lambda: True), # Eliminated
|
||||
}
|
||||
client.game.powers = mock_powers
|
||||
|
||||
# Mock message history
|
||||
mock_messages = [
|
||||
Message(sender="ENGLAND", recipient="FRANCE", message="Hello France!", phase="S1901M"),
|
||||
Message(sender="GERMANY", recipient="FRANCE", message="Greetings!", phase="S1901M"),
|
||||
Message(sender="FRANCE", recipient="GLOBAL", message="Hello everyone!", phase="S1901M"),
|
||||
]
|
||||
client.game.messages = {i: msg for i, msg in enumerate(mock_messages)}
|
||||
|
||||
return client
|
||||
|
||||
async def test_send_broadcast_message(self, mock_client):
|
||||
"""Test the new send_broadcast_message method."""
|
||||
# Mock the send_message method and powers property
|
||||
mock_client.send_message = AsyncMock()
|
||||
|
||||
# Mock the powers property to return our mock powers
|
||||
with patch.object(type(mock_client), "powers", new_callable=lambda: property(lambda self: self.game.powers)):
|
||||
# Send broadcast message
|
||||
await mock_client.send_broadcast_message(sender="FRANCE", message="Hello to all active powers!")
|
||||
|
||||
# Should have called send_message for each active power (excluding sender and eliminated)
|
||||
expected_calls = 2 # ENGLAND and GERMANY (not ITALY because eliminated, not FRANCE because sender)
|
||||
assert mock_client.send_message.call_count == expected_calls
|
||||
|
||||
# Check that messages were sent to the right powers
|
||||
call_args_list = mock_client.send_message.call_args_list
|
||||
recipients = [call[1]["recipient"] for call in call_args_list] # Get recipient from kwargs
|
||||
|
||||
assert "ENGLAND" in recipients
|
||||
assert "GERMANY" in recipients
|
||||
assert "FRANCE" not in recipients # Shouldn't send to self
|
||||
assert "ITALY" not in recipients # Shouldn't send to eliminated power
|
||||
|
||||
async def test_get_recent_messages_filtering(self, mock_client):
|
||||
"""Test the new get_recent_messages method with filtering."""
|
||||
# Mock current phase
|
||||
mock_client.get_current_short_phase = MagicMock(return_value="S1901M")
|
||||
|
||||
# Mock the messages property to return our mock messages
|
||||
with patch.object(type(mock_client), "messages", new_callable=lambda: property(lambda self: self.messages)):
|
||||
# Test getting recent messages
|
||||
recent_messages = await mock_client.get_recent_messages(limit=5)
|
||||
|
||||
# Should return the mocked messages
|
||||
assert len(recent_messages) <= 5
|
||||
assert len(recent_messages) == 3 # We have 3 mock messages
|
||||
|
||||
# Test phase filtering
|
||||
phase_messages = await mock_client.get_recent_messages(phase="S1901M", limit=10)
|
||||
assert all(msg.phase == "S1901M" for msg in phase_messages)
|
||||
|
||||
async def test_get_recent_messages_limit(self, mock_client):
|
||||
"""Test that the limit parameter works correctly."""
|
||||
# Mock the messages property to return our mock messages
|
||||
with patch.object(type(mock_client), "messages", new_callable=lambda: property(lambda self: self.messages)):
|
||||
# Test with limit smaller than available messages
|
||||
limited_messages = await mock_client.get_recent_messages(limit=2)
|
||||
assert len(limited_messages) <= 2
|
||||
|
||||
# Test with limit larger than available messages
|
||||
all_messages = await mock_client.get_recent_messages(limit=100)
|
||||
assert len(all_messages) == 3 # Should not exceed available messages
|
||||
|
||||
|
||||
class TestSingleBotPlayerMessageTracking:
|
||||
"""Test the message tracking functionality in SingleBotPlayer."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_bot_player(self):
|
||||
"""Create a mock SingleBotPlayer for testing."""
|
||||
bot = SingleBotPlayer(username="test_bot", password="test_pass", power_name="FRANCE", model_name="test_model")
|
||||
|
||||
# Mock dependencies
|
||||
bot.client = MagicMock()
|
||||
bot.agent = MagicMock()
|
||||
bot.game_history = MagicMock()
|
||||
|
||||
# Set up powers
|
||||
bot.client.powers = {
|
||||
"FRANCE": MagicMock(is_eliminated=lambda: False),
|
||||
"ENGLAND": MagicMock(is_eliminated=lambda: False),
|
||||
"GERMANY": MagicMock(is_eliminated=lambda: False),
|
||||
"ITALY": MagicMock(is_eliminated=lambda: False),
|
||||
}
|
||||
|
||||
return bot
|
||||
|
||||
def test_message_counting_initialization(self, mock_bot_player):
|
||||
"""Test that message tracking starts in clean state."""
|
||||
bot = mock_bot_player
|
||||
|
||||
assert isinstance(bot.message_counts, dict)
|
||||
assert isinstance(bot.response_counts, dict)
|
||||
assert isinstance(bot.priority_contacts, list)
|
||||
|
||||
assert len(bot.message_counts) == 0
|
||||
assert len(bot.response_counts) == 0
|
||||
assert len(bot.priority_contacts) == 0
|
||||
|
||||
def test_priority_contact_updates(self, mock_bot_player):
|
||||
"""Test that priority contacts are updated correctly."""
|
||||
bot = mock_bot_player
|
||||
|
||||
# Set up message counts
|
||||
bot.message_counts = {
|
||||
"ENGLAND": 5,
|
||||
"GERMANY": 3,
|
||||
"ITALY": 2,
|
||||
"AUSTRIA": 1,
|
||||
"RUSSIA": 4,
|
||||
}
|
||||
|
||||
# Update priority contacts
|
||||
bot._update_priority_contacts()
|
||||
|
||||
# Should have top 4 contacts in order of activity
|
||||
assert len(bot.priority_contacts) == 4
|
||||
assert bot.priority_contacts[0] == "ENGLAND" # Highest count (5)
|
||||
assert bot.priority_contacts[1] == "RUSSIA" # Second highest (4)
|
||||
assert bot.priority_contacts[2] == "GERMANY" # Third highest (3)
|
||||
assert bot.priority_contacts[3] == "ITALY" # Fourth highest (2)
|
||||
|
||||
def test_message_statistics_generation(self, mock_bot_player):
|
||||
"""Test generation of message statistics."""
|
||||
bot = mock_bot_player
|
||||
bot.current_phase = "S1901M"
|
||||
|
||||
# Set up data
|
||||
bot.message_counts = {"ENGLAND": 4, "GERMANY": 2}
|
||||
bot.response_counts = {"ENGLAND": 3, "GERMANY": 1}
|
||||
bot.priority_contacts = ["ENGLAND", "GERMANY"]
|
||||
|
||||
# Generate statistics
|
||||
stats = bot.get_message_statistics()
|
||||
|
||||
# Check basic structure
|
||||
assert stats["power_name"] == "FRANCE"
|
||||
assert stats["current_phase"] == "S1901M"
|
||||
assert stats["total_messages_received"] == 6 # 4 + 2
|
||||
assert stats["total_responses_sent"] == 4 # 3 + 1
|
||||
|
||||
# Check response rates
|
||||
assert stats["response_rate_by_power"]["ENGLAND"] == 0.75 # 3/4
|
||||
assert stats["response_rate_by_power"]["GERMANY"] == 0.5 # 1/2
|
||||
|
||||
def test_response_decision_logic(self, mock_bot_player):
|
||||
"""Test the enhanced response decision logic."""
|
||||
bot = mock_bot_player
|
||||
bot.priority_contacts = ["ENGLAND"]
|
||||
|
||||
# Create test messages
|
||||
priority_message = MagicMock()
|
||||
priority_message.sender = "ENGLAND"
|
||||
priority_message.recipient = "FRANCE"
|
||||
priority_message.message = "Hello France!"
|
||||
|
||||
non_priority_message = MagicMock()
|
||||
non_priority_message.sender = "GERMANY"
|
||||
non_priority_message.recipient = "FRANCE"
|
||||
non_priority_message.message = "Hello France!"
|
||||
|
||||
# Test the decision logic (extracted from _consider_message_response)
|
||||
def should_respond(message):
|
||||
message_lower = message.message.lower()
|
||||
strategic_keywords = ["alliance", "deal", "propose", "agreement"]
|
||||
|
||||
return any(
|
||||
[
|
||||
"?" in message.message,
|
||||
any(word in message_lower for word in ["hello", "hi", "greetings"]),
|
||||
any(keyword in message_lower for keyword in strategic_keywords),
|
||||
len(message.message.split()) > 15,
|
||||
message.sender in bot.priority_contacts,
|
||||
]
|
||||
)
|
||||
|
||||
# Both should respond due to "hello", but priority logic is tested
|
||||
assert should_respond(priority_message) is True
|
||||
assert should_respond(non_priority_message) is True
|
||||
|
||||
# Test priority contact influence
|
||||
assert priority_message.sender in bot.priority_contacts
|
||||
assert non_priority_message.sender not in bot.priority_contacts
|
||||
|
||||
|
||||
class TestNegotiationTargeting:
|
||||
"""Test the strategic negotiation targeting logic."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_client_with_messages(self):
|
||||
"""Create a mock client with message history."""
|
||||
client = MagicMock()
|
||||
|
||||
# Mock recent messages
|
||||
mock_messages = [
|
||||
Message(sender="ENGLAND", recipient="FRANCE", message="Direct to France", phase="S1901M"),
|
||||
Message(sender="ENGLAND", recipient="FRANCE", message="Another to France", phase="S1901M"),
|
||||
Message(sender="GERMANY", recipient="FRANCE", message="Message to France", phase="S1901M"),
|
||||
Message(sender="ITALY", recipient="GLOBAL", message="Global message", phase="S1901M"),
|
||||
]
|
||||
|
||||
client.get_recent_messages = AsyncMock(return_value=mock_messages)
|
||||
client.powers = {
|
||||
"FRANCE": MagicMock(is_eliminated=lambda: False),
|
||||
"ENGLAND": MagicMock(is_eliminated=lambda: False),
|
||||
"GERMANY": MagicMock(is_eliminated=lambda: False),
|
||||
"ITALY": MagicMock(is_eliminated=lambda: False),
|
||||
}
|
||||
|
||||
return client
|
||||
|
||||
async def test_analyze_recent_messages_for_targeting(self, mock_client_with_messages):
|
||||
"""Test the message analysis for targeting."""
|
||||
client = mock_client_with_messages
|
||||
|
||||
# Analyze targeting for FRANCE
|
||||
targets = await analyze_recent_messages_for_targeting(client, "FRANCE", max_messages=20)
|
||||
|
||||
# Should return a list of powers
|
||||
assert isinstance(targets, list)
|
||||
assert "FRANCE" not in targets # Should not include self
|
||||
|
||||
# ENGLAND should be prioritized (sent 2 direct messages to FRANCE)
|
||||
# GERMANY should be second (sent 1 direct message to FRANCE)
|
||||
if len(targets) >= 2:
|
||||
assert targets[0] == "ENGLAND" # Most direct messages
|
||||
if "GERMANY" in targets:
|
||||
germany_index = targets.index("GERMANY")
|
||||
england_index = targets.index("ENGLAND")
|
||||
assert england_index < germany_index # England should come before Germany
|
||||
|
||||
def test_negotiation_delay_calculation(self):
|
||||
"""Test that negotiation delays are calculated correctly."""
|
||||
# Test different round scenarios
|
||||
first_delay = get_negotiation_delay(round_number=1, total_rounds=3)
|
||||
middle_delay = get_negotiation_delay(round_number=2, total_rounds=3)
|
||||
final_delay = get_negotiation_delay(round_number=3, total_rounds=3)
|
||||
|
||||
# First round should have longer delay
|
||||
assert first_delay > middle_delay
|
||||
# Final round should have shorter delay
|
||||
assert final_delay < middle_delay
|
||||
# All delays should be positive
|
||||
assert all(delay > 0 for delay in [first_delay, middle_delay, final_delay])
|
||||
|
||||
async def test_should_participate_in_negotiations(self):
|
||||
"""Test negotiation participation logic."""
|
||||
# Mock client and agent
|
||||
mock_client = MagicMock()
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.power_name = "FRANCE"
|
||||
|
||||
# Test case: active power in movement phase
|
||||
mock_power = MagicMock()
|
||||
mock_power.is_eliminated.return_value = False
|
||||
mock_client.get_power.return_value = mock_power
|
||||
mock_client.get_current_short_phase.return_value = "S1901M"
|
||||
|
||||
with patch("websocket_negotiations.gather_possible_orders") as mock_orders:
|
||||
mock_orders.return_value = ["A Paris - Hold"] # Has orders
|
||||
|
||||
result = await should_participate_in_negotiations(mock_client, mock_agent)
|
||||
assert result is True
|
||||
|
||||
# Test case: eliminated power
|
||||
mock_power.is_eliminated.return_value = True
|
||||
result = await should_participate_in_negotiations(mock_client, mock_agent)
|
||||
assert result is False
|
||||
|
||||
# Test case: non-movement phase
|
||||
mock_power.is_eliminated.return_value = False
|
||||
mock_client.get_current_short_phase.return_value = "S1901R" # Retreat phase
|
||||
|
||||
with patch("websocket_negotiations.gather_possible_orders") as mock_orders:
|
||||
mock_orders.return_value = ["A Paris - Hold"]
|
||||
|
||||
result = await should_participate_in_negotiations(mock_client, mock_agent)
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestIntegrationScenarios:
|
||||
"""Test integration scenarios with mocked components."""
|
||||
|
||||
def test_message_persistence_across_phases(self):
|
||||
"""Test that message counts persist across multiple game phases."""
|
||||
bot = SingleBotPlayer(username="test_bot", password="test_pass", power_name="AUSTRIA", model_name="test_model")
|
||||
|
||||
# Mock dependencies
|
||||
bot.client = MagicMock()
|
||||
bot.agent = MagicMock()
|
||||
bot.game_history = MagicMock()
|
||||
|
||||
# Simulate message accumulation over phases
|
||||
phases = ["S1901M", "F1901M", "W1901A", "S1902M"]
|
||||
|
||||
for phase in phases:
|
||||
bot.current_phase = phase
|
||||
|
||||
# Simulate receiving messages
|
||||
if phase.endswith("M"): # Movement phases
|
||||
for sender in ["FRANCE", "ENGLAND"]:
|
||||
bot.message_counts[sender] = bot.message_counts.get(sender, 0) + 2
|
||||
else:
|
||||
# Fewer messages in other phases
|
||||
bot.message_counts["FRANCE"] = bot.message_counts.get("FRANCE", 0) + 1
|
||||
|
||||
bot._update_priority_contacts()
|
||||
|
||||
# Check accumulated counts
|
||||
assert bot.message_counts["FRANCE"] >= 5 # Should have accumulated messages
|
||||
assert bot.message_counts["ENGLAND"] >= 4
|
||||
assert len(bot.priority_contacts) > 0
|
||||
assert bot.priority_contacts[0] == "FRANCE" # Should be top priority
|
||||
403
bot_client/tests/test_message_persistence.py
Normal file
403
bot_client/tests/test_message_persistence.py
Normal file
|
|
@ -0,0 +1,403 @@
|
|||
"""
|
||||
Integration tests for message persistence and tracking functionality.
|
||||
|
||||
These tests verify that the SingleBotPlayer can properly:
|
||||
1. Track message counts and response patterns
|
||||
2. Update priority contacts based on messaging activity
|
||||
3. Generate accurate message statistics
|
||||
4. Log statistics during phase transitions
|
||||
5. Maintain message persistence across game phases
|
||||
|
||||
Tests focus on the tracking and statistics features rather than
|
||||
the actual AI message generation.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from typing import Dict
|
||||
|
||||
from single_bot_player import SingleBotPlayer
|
||||
from websocket_diplomacy_client import WebSocketDiplomacyClient
|
||||
from diplomacy.engine.message import Message
|
||||
from ai_diplomacy.agent import DiplomacyAgent
|
||||
from ai_diplomacy.game_history import GameHistory
|
||||
|
||||
|
||||
class TestMessageCounting:
|
||||
"""Test message counting and tracking functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_bot_player(self):
|
||||
"""Create a SingleBotPlayer with mocked dependencies for testing."""
|
||||
# Create mock bot player
|
||||
bot = SingleBotPlayer(username="test_bot", password="test_pass", power_name="ENGLAND", model_name="test_model", game_id="test_game")
|
||||
|
||||
# Mock the client and agent
|
||||
bot.client = MagicMock(spec=WebSocketDiplomacyClient)
|
||||
bot.agent = MagicMock(spec=DiplomacyAgent)
|
||||
bot.game_history = GameHistory()
|
||||
|
||||
return bot
|
||||
|
||||
def test_initial_message_tracking_state(self, mock_bot_player):
|
||||
"""Test that message tracking starts in clean state."""
|
||||
bot = mock_bot_player
|
||||
|
||||
assert isinstance(bot.message_counts, dict)
|
||||
assert isinstance(bot.response_counts, dict)
|
||||
assert isinstance(bot.priority_contacts, list)
|
||||
|
||||
assert len(bot.message_counts) == 0
|
||||
assert len(bot.response_counts) == 0
|
||||
assert len(bot.priority_contacts) == 0
|
||||
|
||||
def test_message_count_tracking(self, mock_bot_player):
|
||||
"""Test that incoming messages are counted correctly."""
|
||||
bot = mock_bot_player
|
||||
|
||||
# Simulate receiving messages from different powers
|
||||
test_messages = [
|
||||
Message(sender="FRANCE", recipient="ENGLAND", message="Hello England!", phase="S1901M"),
|
||||
Message(sender="FRANCE", recipient="ENGLAND", message="Another message", phase="S1901M"),
|
||||
Message(sender="GERMANY", recipient="ENGLAND", message="Greetings!", phase="S1901M"),
|
||||
Message(sender="ITALY", recipient="GLOBAL", message="Global message", phase="S1901M"), # Should not be counted
|
||||
Message(sender="ENGLAND", recipient="FRANCE", message="Self message", phase="S1901M"), # Should not be counted
|
||||
]
|
||||
|
||||
# Process each message through the message handler
|
||||
for msg in test_messages:
|
||||
# Simulate the message handling logic from _on_message_received
|
||||
bot.game_history.add_message(
|
||||
phase_name=msg.phase,
|
||||
sender=msg.sender,
|
||||
recipient=msg.recipient,
|
||||
message_content=msg.message,
|
||||
)
|
||||
|
||||
# Track message patterns (only for messages TO this bot)
|
||||
if msg.recipient == bot.power_name and msg.sender != bot.power_name:
|
||||
bot.message_counts[msg.sender] = bot.message_counts.get(msg.sender, 0) + 1
|
||||
bot._update_priority_contacts()
|
||||
|
||||
# Check counts
|
||||
assert bot.message_counts["FRANCE"] == 2
|
||||
assert bot.message_counts["GERMANY"] == 1
|
||||
assert "ITALY" not in bot.message_counts # Global message not counted
|
||||
assert "ENGLAND" not in bot.message_counts # Self message not counted
|
||||
|
||||
def test_response_count_tracking(self, mock_bot_player):
|
||||
"""Test that outgoing responses are counted correctly."""
|
||||
bot = mock_bot_player
|
||||
|
||||
# Simulate sending responses to different powers
|
||||
responses = [
|
||||
("FRANCE", "Thanks for your message!"),
|
||||
("FRANCE", "Another response to France"),
|
||||
("GERMANY", "Hello Germany"),
|
||||
]
|
||||
|
||||
for recipient, message in responses:
|
||||
# Simulate response sending logic
|
||||
bot.response_counts[recipient] = bot.response_counts.get(recipient, 0) + 1
|
||||
|
||||
# Check response counts
|
||||
assert bot.response_counts["FRANCE"] == 2
|
||||
assert bot.response_counts["GERMANY"] == 1
|
||||
|
||||
def test_priority_contact_updates(self, mock_bot_player):
|
||||
"""Test that priority contacts are updated based on message counts."""
|
||||
bot = mock_bot_player
|
||||
|
||||
# Set up message counts with different activity levels
|
||||
bot.message_counts = {
|
||||
"FRANCE": 5, # Most active
|
||||
"GERMANY": 3, # Second most active
|
||||
"ITALY": 2, # Third most active
|
||||
"AUSTRIA": 1, # Least active
|
||||
"RUSSIA": 4, # Second highest
|
||||
}
|
||||
|
||||
# Update priority contacts
|
||||
bot._update_priority_contacts()
|
||||
|
||||
# Should have top 4 contacts in order of activity
|
||||
assert len(bot.priority_contacts) == 4
|
||||
assert bot.priority_contacts[0] == "FRANCE" # Highest count (5)
|
||||
assert bot.priority_contacts[1] == "RUSSIA" # Second highest (4)
|
||||
assert bot.priority_contacts[2] == "GERMANY" # Third highest (3)
|
||||
assert bot.priority_contacts[3] == "ITALY" # Fourth highest (2)
|
||||
# AUSTRIA should not be in top 4
|
||||
|
||||
def test_priority_contacts_with_fewer_powers(self, mock_bot_player):
|
||||
"""Test priority contacts when fewer than 4 powers are active."""
|
||||
bot = mock_bot_player
|
||||
|
||||
# Set up message counts with only 2 powers
|
||||
bot.message_counts = {
|
||||
"FRANCE": 3,
|
||||
"GERMANY": 1,
|
||||
}
|
||||
|
||||
bot._update_priority_contacts()
|
||||
|
||||
# Should have only 2 contacts
|
||||
assert len(bot.priority_contacts) == 2
|
||||
assert bot.priority_contacts[0] == "FRANCE"
|
||||
assert bot.priority_contacts[1] == "GERMANY"
|
||||
|
||||
|
||||
class TestMessageStatistics:
|
||||
"""Test message statistics generation and reporting."""
|
||||
|
||||
@pytest.fixture
|
||||
def bot_with_message_data(self, mock_bot_player):
|
||||
"""Bot player with pre-populated message tracking data."""
|
||||
bot = mock_bot_player
|
||||
|
||||
# Set up realistic message and response data
|
||||
bot.message_counts = {
|
||||
"FRANCE": 8,
|
||||
"GERMANY": 5,
|
||||
"ITALY": 3,
|
||||
"RUSSIA": 2,
|
||||
}
|
||||
|
||||
bot.response_counts = {
|
||||
"FRANCE": 6, # 75% response rate
|
||||
"GERMANY": 2, # 40% response rate
|
||||
"ITALY": 3, # 100% response rate
|
||||
"RUSSIA": 0, # 0% response rate
|
||||
}
|
||||
|
||||
bot.priority_contacts = ["FRANCE", "GERMANY", "ITALY", "RUSSIA"]
|
||||
bot.current_phase = "S1901M"
|
||||
|
||||
# Mock the client's powers for active power detection
|
||||
mock_powers = {
|
||||
"ENGLAND": MagicMock(is_eliminated=lambda: False),
|
||||
"FRANCE": MagicMock(is_eliminated=lambda: False),
|
||||
"GERMANY": MagicMock(is_eliminated=lambda: False),
|
||||
"ITALY": MagicMock(is_eliminated=lambda: False),
|
||||
"RUSSIA": MagicMock(is_eliminated=lambda: False),
|
||||
"AUSTRIA": MagicMock(is_eliminated=lambda: True), # Eliminated
|
||||
"TURKEY": MagicMock(is_eliminated=lambda: False),
|
||||
}
|
||||
bot.client.powers = mock_powers
|
||||
|
||||
return bot
|
||||
|
||||
def test_message_statistics_generation(self, bot_with_message_data):
|
||||
"""Test that message statistics are generated correctly."""
|
||||
bot = bot_with_message_data
|
||||
|
||||
stats = bot.get_message_statistics()
|
||||
|
||||
# Check basic structure
|
||||
assert stats["power_name"] == "ENGLAND"
|
||||
assert stats["current_phase"] == "S1901M"
|
||||
assert isinstance(stats["message_counts_by_power"], dict)
|
||||
assert isinstance(stats["response_counts_by_power"], dict)
|
||||
assert isinstance(stats["response_rate_by_power"], dict)
|
||||
assert isinstance(stats["priority_contacts"], list)
|
||||
assert isinstance(stats["active_powers"], list)
|
||||
|
||||
# Check calculated values
|
||||
assert stats["total_messages_received"] == 18 # 8+5+3+2
|
||||
assert stats["total_responses_sent"] == 11 # 6+2+3+0
|
||||
|
||||
# Check response rates
|
||||
assert stats["response_rate_by_power"]["FRANCE"] == 0.75 # 6/8
|
||||
assert stats["response_rate_by_power"]["GERMANY"] == 0.4 # 2/5
|
||||
assert stats["response_rate_by_power"]["ITALY"] == 1.0 # 3/3
|
||||
assert stats["response_rate_by_power"]["RUSSIA"] == 0.0 # 0/2
|
||||
|
||||
# Check active powers (should exclude eliminated AUSTRIA and self)
|
||||
active_powers = stats["active_powers"]
|
||||
assert "AUSTRIA" not in active_powers # Eliminated
|
||||
assert "ENGLAND" not in active_powers # Self
|
||||
assert "FRANCE" in active_powers
|
||||
assert "TURKEY" in active_powers
|
||||
|
||||
def test_message_statistics_empty_data(self, mock_bot_player):
|
||||
"""Test statistics generation with no message data."""
|
||||
bot = mock_bot_player
|
||||
bot.current_phase = "S1901M"
|
||||
|
||||
# Mock empty powers
|
||||
mock_powers = {
|
||||
"ENGLAND": MagicMock(is_eliminated=lambda: False),
|
||||
"FRANCE": MagicMock(is_eliminated=lambda: False),
|
||||
}
|
||||
bot.client.powers = mock_powers
|
||||
|
||||
stats = bot.get_message_statistics()
|
||||
|
||||
assert stats["total_messages_received"] == 0
|
||||
assert stats["total_responses_sent"] == 0
|
||||
assert len(stats["message_counts_by_power"]) == 0
|
||||
assert len(stats["response_counts_by_power"]) == 0
|
||||
assert len(stats["response_rate_by_power"]) == 0
|
||||
assert len(stats["priority_contacts"]) == 0
|
||||
|
||||
def test_message_statistics_logging(self, bot_with_message_data, caplog):
|
||||
"""Test that message statistics are logged correctly."""
|
||||
bot = bot_with_message_data
|
||||
|
||||
# Call the logging method
|
||||
bot.log_message_statistics()
|
||||
|
||||
# Check that appropriate log messages were generated
|
||||
log_output = caplog.text
|
||||
assert "Message Statistics for ENGLAND" in log_output
|
||||
assert "Total messages received: 18" in log_output
|
||||
assert "Total responses sent: 11" in log_output
|
||||
assert "Priority contacts:" in log_output
|
||||
|
||||
# Check that individual power stats are logged
|
||||
assert "FRANCE:" in log_output
|
||||
assert "GERMANY:" in log_output
|
||||
assert "75%" in log_output # France response rate
|
||||
assert "40%" in log_output # Germany response rate
|
||||
|
||||
|
||||
class TestMessagePersistenceIntegration:
|
||||
"""Test integration of message persistence with game flow."""
|
||||
|
||||
@pytest.fixture
|
||||
async def bot_integration_setup(self, fake_server):
|
||||
"""Setup for integration testing with fake server."""
|
||||
# Note: This is a more complex fixture that would require actual
|
||||
# SingleBotPlayer initialization, which depends on AI client setup
|
||||
# For now, we'll focus on the core tracking logic
|
||||
|
||||
bot = SingleBotPlayer(username="integration_test", password="test_pass", power_name="ITALY", model_name="test_model", game_id=None)
|
||||
|
||||
# Mock the complex dependencies
|
||||
bot.client = MagicMock(spec=WebSocketDiplomacyClient)
|
||||
bot.agent = MagicMock(spec=DiplomacyAgent)
|
||||
bot.game_history = GameHistory()
|
||||
|
||||
return bot
|
||||
|
||||
def test_message_tracking_during_phase_transition(self, bot_integration_setup):
|
||||
"""Test that message statistics are logged during phase transitions."""
|
||||
bot = bot_integration_setup
|
||||
|
||||
# Set up some message data
|
||||
bot.message_counts = {"FRANCE": 2, "GERMANY": 1}
|
||||
bot.current_phase = "S1901M"
|
||||
|
||||
# Mock the client powers for log_message_statistics
|
||||
mock_powers = {
|
||||
"ITALY": MagicMock(is_eliminated=lambda: False),
|
||||
"FRANCE": MagicMock(is_eliminated=lambda: False),
|
||||
"GERMANY": MagicMock(is_eliminated=lambda: False),
|
||||
}
|
||||
bot.client.powers = mock_powers
|
||||
|
||||
# Mock the log_message_statistics method to track if it's called
|
||||
with patch.object(bot, "log_message_statistics") as mock_log:
|
||||
# Simulate phase transition logic
|
||||
new_phase = "F1901M"
|
||||
if new_phase != bot.current_phase:
|
||||
bot.current_phase = new_phase
|
||||
bot.game_history.add_phase(new_phase)
|
||||
bot.orders_submitted = False
|
||||
bot.current_negotiation_round = 0
|
||||
bot.negotiation_complete = False
|
||||
|
||||
# This is the key logic from _handle_phase_update_async
|
||||
if hasattr(bot, "message_counts") and bot.message_counts:
|
||||
bot.log_message_statistics()
|
||||
|
||||
# Verify that statistics were logged
|
||||
mock_log.assert_called_once()
|
||||
|
||||
def test_priority_contact_influence_on_response_decisions(self, bot_integration_setup):
|
||||
"""Test that priority contacts influence message response decisions."""
|
||||
bot = bot_integration_setup
|
||||
|
||||
# Set up priority contacts
|
||||
bot.priority_contacts = ["FRANCE", "GERMANY"]
|
||||
|
||||
# Create test messages from different senders
|
||||
priority_message = Message(sender="FRANCE", recipient="ITALY", message="Hello Italy!", phase="S1901M")
|
||||
|
||||
non_priority_message = Message(sender="AUSTRIA", recipient="ITALY", message="Hello Italy!", phase="S1901M")
|
||||
|
||||
# Test the response decision logic (from _consider_message_response)
|
||||
def should_respond_to_message(message):
|
||||
message_lower = message.message.lower()
|
||||
strategic_keywords = [
|
||||
"alliance",
|
||||
"deal",
|
||||
"propose",
|
||||
"agreement",
|
||||
"support",
|
||||
"attack",
|
||||
"coordinate",
|
||||
"move",
|
||||
"order",
|
||||
"help",
|
||||
"work together",
|
||||
"partner",
|
||||
"enemy",
|
||||
"threat",
|
||||
"negotiate",
|
||||
"discuss",
|
||||
"plan",
|
||||
"strategy",
|
||||
"bounce",
|
||||
"convoy",
|
||||
"retreat",
|
||||
]
|
||||
|
||||
return any(
|
||||
[
|
||||
"?" in message.message, # Questions
|
||||
any(word in message_lower for word in ["hello", "hi", "greetings"]), # Greetings
|
||||
any(keyword in message_lower for keyword in strategic_keywords), # Strategic content
|
||||
len(message.message.split()) > 15, # Longer messages
|
||||
message.sender in bot.priority_contacts, # Priority contacts
|
||||
]
|
||||
)
|
||||
|
||||
# Priority contact should be more likely to get response
|
||||
priority_should_respond = should_respond_to_message(priority_message)
|
||||
non_priority_should_respond = should_respond_to_message(non_priority_message)
|
||||
|
||||
# Both should respond due to "hello" keyword, but priority contact logic is working
|
||||
assert priority_should_respond is True
|
||||
# Non-priority should also respond due to "hello", but test the contact logic
|
||||
assert priority_message.sender in bot.priority_contacts
|
||||
assert non_priority_message.sender not in bot.priority_contacts
|
||||
|
||||
def test_message_persistence_across_multiple_phases(self, bot_integration_setup):
|
||||
"""Test that message counts persist across multiple game phases."""
|
||||
bot = bot_integration_setup
|
||||
|
||||
# Simulate message accumulation over multiple phases
|
||||
phases = ["S1901M", "F1901M", "W1901A", "S1902M"]
|
||||
|
||||
for phase in phases:
|
||||
bot.current_phase = phase
|
||||
|
||||
# Simulate receiving messages in each phase
|
||||
if phase.endswith("M"): # Movement phases
|
||||
# More messages during movement phases
|
||||
for sender in ["FRANCE", "GERMANY"]:
|
||||
bot.message_counts[sender] = bot.message_counts.get(sender, 0) + 2
|
||||
else:
|
||||
# Fewer messages during other phases
|
||||
bot.message_counts["FRANCE"] = bot.message_counts.get("FRANCE", 0) + 1
|
||||
|
||||
bot._update_priority_contacts()
|
||||
|
||||
# After all phases, check accumulated counts
|
||||
assert bot.message_counts["FRANCE"] >= 5 # 2+2+1+2 from movement phases + 1 from adjustment
|
||||
assert bot.message_counts["GERMANY"] >= 4 # 2+2+2 from movement phases only
|
||||
|
||||
# France should be top priority due to higher count
|
||||
assert len(bot.priority_contacts) > 0
|
||||
assert bot.priority_contacts[0] == "FRANCE"
|
||||
440
bot_client/tests/test_multi_bot_scenarios.py
Normal file
440
bot_client/tests/test_multi_bot_scenarios.py
Normal file
|
|
@ -0,0 +1,440 @@
|
|||
"""
|
||||
Integration tests for multi-bot messaging scenarios.
|
||||
|
||||
These tests verify realistic inter-power communication scenarios with
|
||||
multiple bots interacting simultaneously. Tests focus on:
|
||||
1. Multi-bot conversation flows
|
||||
2. Negotiation coordination between multiple powers
|
||||
3. Message response patterns in realistic scenarios
|
||||
4. Phase transition coordination across multiple bots
|
||||
|
||||
Tests use the fake server but simulate realistic multi-power interactions
|
||||
without mocking the core message transport logic.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from typing import Dict, List
|
||||
|
||||
from websocket_diplomacy_client import WebSocketDiplomacyClient
|
||||
from single_bot_player import SingleBotPlayer
|
||||
from ai_diplomacy.agent import DiplomacyAgent
|
||||
from ai_diplomacy.game_history import GameHistory
|
||||
|
||||
|
||||
class TestTwoBotConversation:
|
||||
"""Test conversation flow between two bots."""
|
||||
|
||||
@pytest.fixture
|
||||
async def two_bot_setup(self, fake_server):
|
||||
"""Setup two bots in the same game for conversation testing."""
|
||||
# Create two WebSocket clients
|
||||
client_france = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
|
||||
client_england = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
|
||||
|
||||
try:
|
||||
# Connect both clients
|
||||
await client_france.connect_and_authenticate("france_bot", "password")
|
||||
await client_england.connect_and_authenticate("england_bot", "password")
|
||||
|
||||
# Create game with France
|
||||
await client_france.create_game(
|
||||
map_name="standard",
|
||||
rules=["IGNORE_ERRORS", "POWER_CHOICE"],
|
||||
power_name="FRANCE",
|
||||
n_controls=2,
|
||||
)
|
||||
|
||||
# England joins the game
|
||||
await client_england.join_game(game_id=client_france.game_id, power_name="ENGLAND")
|
||||
|
||||
# Synchronize both clients
|
||||
await client_france.synchronize()
|
||||
await client_england.synchronize()
|
||||
|
||||
yield {"france_client": client_france, "england_client": client_england, "game_id": client_france.game_id}
|
||||
|
||||
finally:
|
||||
try:
|
||||
await client_france.close()
|
||||
await client_england.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
async def test_basic_two_bot_exchange(self, two_bot_setup):
|
||||
"""Test basic message exchange between two bots."""
|
||||
france_client = two_bot_setup["france_client"]
|
||||
england_client = two_bot_setup["england_client"]
|
||||
|
||||
# France initiates conversation
|
||||
await france_client.send_message(sender="FRANCE", recipient="ENGLAND", message="Hello England! Shall we discuss our border?")
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
await england_client.synchronize()
|
||||
|
||||
# Check England received the message
|
||||
england_messages = await england_client.get_recent_messages(limit=5)
|
||||
france_message = None
|
||||
for msg in england_messages:
|
||||
if msg.sender == "FRANCE" and msg.recipient == "ENGLAND":
|
||||
france_message = msg
|
||||
break
|
||||
|
||||
assert france_message is not None
|
||||
assert "border" in france_message.message
|
||||
|
||||
# England responds
|
||||
await england_client.send_message(sender="ENGLAND", recipient="FRANCE", message="Indeed, France. I propose we coordinate our fleets.")
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
await france_client.synchronize()
|
||||
|
||||
# Check France received the response
|
||||
france_messages = await france_client.get_recent_messages(limit=5)
|
||||
england_response = None
|
||||
for msg in france_messages:
|
||||
if msg.sender == "ENGLAND" and msg.recipient == "FRANCE":
|
||||
england_response = msg
|
||||
break
|
||||
|
||||
assert england_response is not None
|
||||
assert "coordinate" in england_response.message
|
||||
|
||||
async def test_conversation_thread_tracking(self, two_bot_setup):
|
||||
"""Test that conversation threads can be tracked across multiple exchanges."""
|
||||
france_client = two_bot_setup["france_client"]
|
||||
england_client = two_bot_setup["england_client"]
|
||||
|
||||
# Simulate a conversation thread
|
||||
conversation = [
|
||||
("FRANCE", "ENGLAND", "England, are you interested in an alliance?"),
|
||||
("ENGLAND", "FRANCE", "Yes France, what do you propose?"),
|
||||
("FRANCE", "ENGLAND", "Let's coordinate attacks on Germany."),
|
||||
("ENGLAND", "FRANCE", "Agreed. I'll move my fleet to support you."),
|
||||
]
|
||||
|
||||
for sender, recipient, message in conversation:
|
||||
if sender == "FRANCE":
|
||||
await france_client.send_message(sender, recipient, message)
|
||||
else:
|
||||
await england_client.send_message(sender, recipient, message)
|
||||
|
||||
await asyncio.sleep(0.05) # Small delay between messages
|
||||
|
||||
# Synchronize both clients
|
||||
await france_client.synchronize()
|
||||
await england_client.synchronize()
|
||||
|
||||
# Get conversation history from both perspectives
|
||||
france_messages = await france_client.get_recent_messages(limit=10)
|
||||
england_messages = await england_client.get_recent_messages(limit=10)
|
||||
|
||||
# Filter for conversation between France and England
|
||||
def filter_conversation(messages):
|
||||
return [msg for msg in messages if (msg.sender in ["FRANCE", "ENGLAND"] and msg.recipient in ["FRANCE", "ENGLAND"])]
|
||||
|
||||
france_conv = filter_conversation(france_messages)
|
||||
england_conv = filter_conversation(england_messages)
|
||||
|
||||
# Both should see the same conversation
|
||||
assert len(france_conv) >= 4
|
||||
assert len(england_conv) >= 4
|
||||
|
||||
# Check that key terms from the conversation appear
|
||||
all_messages_text = " ".join([msg.message for msg in france_conv])
|
||||
assert "alliance" in all_messages_text
|
||||
assert "Germany" in all_messages_text
|
||||
assert "support" in all_messages_text
|
||||
|
||||
|
||||
class TestThreeBotNegotiation:
|
||||
"""Test more complex negotiations with three bots."""
|
||||
|
||||
@pytest.fixture
|
||||
async def three_bot_setup(self, fake_server):
|
||||
"""Setup three bots for complex negotiation testing."""
|
||||
clients = {}
|
||||
powers = ["FRANCE", "ENGLAND", "GERMANY"]
|
||||
|
||||
try:
|
||||
# Create and connect three clients
|
||||
for power in powers:
|
||||
client = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
|
||||
await client.connect_and_authenticate(f"{power.lower()}_bot", "password")
|
||||
clients[power] = client
|
||||
|
||||
# Create game with France
|
||||
await clients["FRANCE"].create_game(
|
||||
map_name="standard",
|
||||
rules=["IGNORE_ERRORS", "POWER_CHOICE"],
|
||||
power_name="FRANCE",
|
||||
n_controls=3,
|
||||
)
|
||||
|
||||
game_id = clients["FRANCE"].game_id
|
||||
|
||||
# Other powers join
|
||||
await clients["ENGLAND"].join_game(game_id=game_id, power_name="ENGLAND")
|
||||
await clients["GERMANY"].join_game(game_id=game_id, power_name="GERMANY")
|
||||
|
||||
# Synchronize all clients
|
||||
for client in clients.values():
|
||||
await client.synchronize()
|
||||
|
||||
yield {"clients": clients, "game_id": game_id}
|
||||
|
||||
finally:
|
||||
for client in clients.values():
|
||||
try:
|
||||
await client.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
async def test_three_way_alliance_negotiation(self, three_bot_setup):
|
||||
"""Test alliance negotiation between three powers."""
|
||||
clients = three_bot_setup["clients"]
|
||||
|
||||
# Simulate alliance negotiation sequence
|
||||
negotiations = [
|
||||
("FRANCE", "ENGLAND", "England, shall we form an alliance against Germany?"),
|
||||
("ENGLAND", "FRANCE", "I'm interested. What are your terms?"),
|
||||
("FRANCE", "GERMANY", "Germany, France and England are discussing cooperation."),
|
||||
("GERMANY", "FRANCE", "I see. Perhaps we should talk as well."),
|
||||
("GERMANY", "ENGLAND", "England, what is France offering you?"),
|
||||
("ENGLAND", "GERMANY", "Germany, I think we should all work together."),
|
||||
]
|
||||
|
||||
# Send all negotiation messages
|
||||
for sender, recipient, message in negotiations:
|
||||
await clients[sender].send_message(sender, recipient, message)
|
||||
await asyncio.sleep(0.1) # Allow message processing
|
||||
|
||||
# Synchronize all clients
|
||||
for client in clients.values():
|
||||
await client.synchronize()
|
||||
|
||||
# Analyze message patterns from each perspective
|
||||
for power, client in clients.items():
|
||||
messages = await client.get_recent_messages(limit=20)
|
||||
|
||||
# Count messages involving this power
|
||||
involving_power = [msg for msg in messages if msg.sender == power or msg.recipient == power]
|
||||
|
||||
# Each power should be involved in multiple messages
|
||||
assert len(involving_power) >= 2, f"{power} should be involved in multiple messages"
|
||||
|
||||
# Check that all three powers have communicated
|
||||
all_messages = await clients["FRANCE"].get_recent_messages(limit=20)
|
||||
|
||||
senders = set(msg.sender for msg in all_messages)
|
||||
recipients = set(msg.recipient for msg in all_messages)
|
||||
|
||||
# All three powers should appear as senders
|
||||
assert "FRANCE" in senders
|
||||
assert "ENGLAND" in senders
|
||||
assert "GERMANY" in senders
|
||||
|
||||
async def test_broadcast_with_private_follow_ups(self, three_bot_setup):
|
||||
"""Test broadcast message followed by private conversations."""
|
||||
clients = three_bot_setup["clients"]
|
||||
|
||||
# France sends a global announcement
|
||||
await clients["FRANCE"].send_broadcast_message(sender="FRANCE", message="All powers: I propose we establish clear spheres of influence.")
|
||||
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# Follow up with private messages to each power
|
||||
private_messages = [
|
||||
("FRANCE", "ENGLAND", "England, I suggest you focus on the seas."),
|
||||
("FRANCE", "GERMANY", "Germany, the eastern approach might suit you."),
|
||||
]
|
||||
|
||||
for sender, recipient, message in private_messages:
|
||||
await clients[sender].send_message(sender, recipient, message)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Synchronize all clients
|
||||
for client in clients.values():
|
||||
await client.synchronize()
|
||||
|
||||
# Check that England and Germany received their specific messages
|
||||
england_messages = await clients["ENGLAND"].get_recent_messages(limit=10)
|
||||
germany_messages = await clients["GERMANY"].get_recent_messages(limit=10)
|
||||
|
||||
# England should see both broadcast (to all) and private message
|
||||
england_private = [msg for msg in england_messages if msg.sender == "FRANCE" and msg.recipient == "ENGLAND"]
|
||||
england_broadcast = [msg for msg in england_messages if msg.sender == "FRANCE" and "spheres of influence" in msg.message]
|
||||
|
||||
assert len(england_private) >= 1, "England should receive private message"
|
||||
assert any("seas" in msg.message for msg in england_private), "England should get seas message"
|
||||
|
||||
# Germany should see broadcast and their private message
|
||||
germany_private = [msg for msg in germany_messages if msg.sender == "FRANCE" and msg.recipient == "GERMANY"]
|
||||
|
||||
assert len(germany_private) >= 1, "Germany should receive private message"
|
||||
assert any("eastern" in msg.message for msg in germany_private), "Germany should get eastern message"
|
||||
|
||||
|
||||
class TestBotPlayerMessageIntegration:
|
||||
"""Test SingleBotPlayer message handling in multi-bot scenarios."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_bot_players(self):
|
||||
"""Create multiple mock SingleBotPlayer instances."""
|
||||
bots = {}
|
||||
powers = ["FRANCE", "ENGLAND", "GERMANY"]
|
||||
|
||||
for power in powers:
|
||||
bot = SingleBotPlayer(
|
||||
username=f"{power.lower()}_bot", password="test_pass", power_name=power, model_name="test_model", game_id="test_game"
|
||||
)
|
||||
|
||||
# Mock dependencies
|
||||
bot.client = MagicMock(spec=WebSocketDiplomacyClient)
|
||||
bot.agent = MagicMock(spec=DiplomacyAgent)
|
||||
bot.game_history = GameHistory()
|
||||
|
||||
# Set up power name correctly
|
||||
bot.agent.power_name = power
|
||||
|
||||
bots[power] = bot
|
||||
|
||||
return bots
|
||||
|
||||
def test_multi_bot_priority_contact_evolution(self, mock_bot_players):
|
||||
"""Test how priority contacts evolve in multi-bot scenarios."""
|
||||
bots = mock_bot_players
|
||||
|
||||
# Simulate asymmetric communication patterns
|
||||
# France talks to everyone, England focuses on France, Germany is quiet
|
||||
|
||||
# France receives messages from multiple powers
|
||||
france_bot = bots["FRANCE"]
|
||||
france_bot.message_counts = {
|
||||
"ENGLAND": 5, # England is very active with France
|
||||
"GERMANY": 2, # Germany occasionally talks to France
|
||||
"ITALY": 1, # Italy sends one message
|
||||
}
|
||||
france_bot._update_priority_contacts()
|
||||
|
||||
# England receives mostly from France
|
||||
england_bot = bots["ENGLAND"]
|
||||
england_bot.message_counts = {
|
||||
"FRANCE": 8, # France talks to England a lot
|
||||
"GERMANY": 1, # Germany sends one message
|
||||
}
|
||||
england_bot._update_priority_contacts()
|
||||
|
||||
# Germany receives few messages
|
||||
germany_bot = bots["GERMANY"]
|
||||
germany_bot.message_counts = {
|
||||
"FRANCE": 3, # Some communication with France
|
||||
"ENGLAND": 1, # Minimal with England
|
||||
}
|
||||
germany_bot._update_priority_contacts()
|
||||
|
||||
# Check that priority contacts reflect communication patterns
|
||||
assert france_bot.priority_contacts[0] == "ENGLAND" # Most active with France
|
||||
assert england_bot.priority_contacts[0] == "FRANCE" # France is England's main contact
|
||||
assert germany_bot.priority_contacts[0] == "FRANCE" # France is Germany's main contact
|
||||
|
||||
# Check list lengths
|
||||
assert len(france_bot.priority_contacts) == 3 # Three powers contacted France
|
||||
assert len(england_bot.priority_contacts) == 2 # Two powers contacted England
|
||||
assert len(germany_bot.priority_contacts) == 2 # Two powers contacted Germany
|
||||
|
||||
def test_response_pattern_analysis(self, mock_bot_players):
|
||||
"""Test analysis of response patterns across multiple bots."""
|
||||
bots = mock_bot_players
|
||||
|
||||
# Set up different response patterns for each bot
|
||||
# France: Responsive to everyone
|
||||
france_bot = bots["FRANCE"]
|
||||
france_bot.message_counts = {"ENGLAND": 4, "GERMANY": 3}
|
||||
france_bot.response_counts = {"ENGLAND": 4, "GERMANY": 3} # 100% response rate
|
||||
|
||||
# England: Selective responder
|
||||
england_bot = bots["ENGLAND"]
|
||||
england_bot.message_counts = {"FRANCE": 6, "GERMANY": 2}
|
||||
england_bot.response_counts = {"FRANCE": 6, "GERMANY": 0} # Only responds to France
|
||||
|
||||
# Germany: Poor responder
|
||||
germany_bot = bots["GERMANY"]
|
||||
germany_bot.message_counts = {"FRANCE": 5, "ENGLAND": 3}
|
||||
germany_bot.response_counts = {"FRANCE": 2, "ENGLAND": 1} # Low response rates
|
||||
|
||||
# Mock client powers for statistics
|
||||
for power, bot in bots.items():
|
||||
mock_powers = {p: MagicMock(is_eliminated=lambda: False) for p in ["FRANCE", "ENGLAND", "GERMANY", "ITALY"]}
|
||||
bot.client.powers = mock_powers
|
||||
|
||||
# Generate statistics for each bot
|
||||
stats = {}
|
||||
for power, bot in bots.items():
|
||||
stats[power] = bot.get_message_statistics()
|
||||
|
||||
# Analyze response patterns
|
||||
france_stats = stats["FRANCE"]
|
||||
england_stats = stats["ENGLAND"]
|
||||
germany_stats = stats["GERMANY"]
|
||||
|
||||
# France should have high overall response rate
|
||||
france_avg_response = sum(france_stats["response_rate_by_power"].values()) / len(france_stats["response_rate_by_power"])
|
||||
assert france_avg_response == 1.0 # Perfect responder
|
||||
|
||||
# England should have selective response pattern
|
||||
assert england_stats["response_rate_by_power"]["FRANCE"] == 1.0 # Always responds to France
|
||||
assert england_stats["response_rate_by_power"]["GERMANY"] == 0.0 # Never responds to Germany
|
||||
|
||||
# Germany should have low overall response rate
|
||||
germany_avg_response = sum(germany_stats["response_rate_by_power"].values()) / len(germany_stats["response_rate_by_power"])
|
||||
assert germany_avg_response < 0.5 # Poor overall response rate
|
||||
|
||||
def test_message_history_consistency_across_bots(self, mock_bot_players):
|
||||
"""Test that message history tracking is consistent across different bots."""
|
||||
bots = mock_bot_players
|
||||
|
||||
# Simulate the same set of messages being processed by different bots
|
||||
# (as they would see them in a real game)
|
||||
|
||||
from diplomacy.engine.message import Message
|
||||
|
||||
shared_messages = [
|
||||
Message(sender="FRANCE", recipient="ENGLAND", message="Alliance proposal", phase="S1901M"),
|
||||
Message(sender="ENGLAND", recipient="FRANCE", message="I accept", phase="S1901M"),
|
||||
Message(sender="GERMANY", recipient="FRANCE", message="What about me?", phase="S1901M"),
|
||||
Message(sender="FRANCE", recipient="GERMANY", message="You're welcome too", phase="S1901M"),
|
||||
]
|
||||
|
||||
# Each bot processes messages relevant to them
|
||||
for message in shared_messages:
|
||||
for power, bot in bots.items():
|
||||
# Add message to game history (all bots see all messages)
|
||||
bot.game_history.add_message(
|
||||
phase_name=message.phase,
|
||||
sender=message.sender,
|
||||
recipient=message.recipient,
|
||||
message_content=message.message,
|
||||
)
|
||||
|
||||
# Track messages directed at this bot
|
||||
if message.recipient == power and message.sender != power:
|
||||
bot.message_counts[message.sender] = bot.message_counts.get(message.sender, 0) + 1
|
||||
bot._update_priority_contacts()
|
||||
|
||||
# Verify that each bot has tracked messages correctly
|
||||
# France should have received 2 messages (from England and Germany)
|
||||
assert bots["FRANCE"].message_counts["ENGLAND"] == 1
|
||||
assert bots["FRANCE"].message_counts["GERMANY"] == 1
|
||||
|
||||
# England should have received 1 message (from France)
|
||||
assert bots["ENGLAND"].message_counts["FRANCE"] == 1
|
||||
|
||||
# Germany should have received 1 message (from France)
|
||||
assert bots["GERMANY"].message_counts["FRANCE"] == 1
|
||||
|
||||
# Check that game history is consistent (all bots see all 4 messages)
|
||||
for power, bot in bots.items():
|
||||
all_messages = bot.game_history.get_messages_for_phase("S1901M")
|
||||
assert len(all_messages) == 4, f"{power} should see all 4 messages in game history"
|
||||
433
bot_client/tests/test_negotiation_rounds.py
Normal file
433
bot_client/tests/test_negotiation_rounds.py
Normal file
|
|
@ -0,0 +1,433 @@
|
|||
"""
|
||||
Integration tests for negotiation round functionality.
|
||||
|
||||
These tests verify that the negotiation system can properly:
|
||||
1. Conduct strategic negotiation rounds with message targeting
|
||||
2. Analyze recent messages for targeting decisions
|
||||
3. Handle negotiation timing and coordination
|
||||
4. Determine negotiation participation appropriately
|
||||
|
||||
Tests use mocked AI responses for predictable behavior while testing
|
||||
the real negotiation coordination logic.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from typing import Dict, List
|
||||
|
||||
from websocket_diplomacy_client import WebSocketDiplomacyClient
|
||||
from websocket_negotiations import (
|
||||
conduct_strategic_negotiation_round,
|
||||
analyze_recent_messages_for_targeting,
|
||||
should_participate_in_negotiations,
|
||||
get_negotiation_delay,
|
||||
)
|
||||
from ai_diplomacy.agent import DiplomacyAgent
|
||||
from ai_diplomacy.game_history import GameHistory
|
||||
|
||||
|
||||
class TestMessageTargeting:
|
||||
"""Test strategic message targeting analysis."""
|
||||
|
||||
@pytest.fixture
|
||||
async def client_with_messages(self, fake_server):
|
||||
"""Client with pre-populated message history for targeting tests."""
|
||||
client = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
|
||||
|
||||
try:
|
||||
await client.connect_and_authenticate("targeting_user", "password")
|
||||
await client.create_game(
|
||||
map_name="standard",
|
||||
rules=["IGNORE_ERRORS", "POWER_CHOICE"],
|
||||
power_name="FRANCE",
|
||||
n_controls=7, # Full game for realistic targeting
|
||||
)
|
||||
await client.synchronize()
|
||||
|
||||
# Simulate message history with different activity levels
|
||||
# Note: In a real scenario, these would come from other players
|
||||
# For testing, we'll add them to the game's message history directly
|
||||
|
||||
yield client
|
||||
|
||||
finally:
|
||||
try:
|
||||
await client.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
async def test_analyze_recent_messages_empty_history(self, client_with_messages):
|
||||
"""Test targeting analysis with no message history."""
|
||||
client = client_with_messages
|
||||
|
||||
# Get targeting priority with empty message history
|
||||
targets = await analyze_recent_messages_for_targeting(client=client, power_name="FRANCE", max_messages=20)
|
||||
|
||||
# Should return all active powers (excluding FRANCE itself)
|
||||
assert isinstance(targets, list)
|
||||
# Should not include FRANCE itself
|
||||
assert "FRANCE" not in targets
|
||||
# Should include other major powers
|
||||
expected_powers = {"ENGLAND", "GERMANY", "ITALY", "AUSTRIA", "RUSSIA", "TURKEY"}
|
||||
assert len(set(targets) & expected_powers) > 0
|
||||
|
||||
async def test_analyze_recent_messages_with_activity(self, client_with_messages):
|
||||
"""Test targeting analysis with simulated message activity."""
|
||||
client = client_with_messages
|
||||
|
||||
# Simulate some message activity by sending messages
|
||||
# This tests the actual message retrieval and analysis logic
|
||||
|
||||
# Send messages to create activity patterns
|
||||
await client.send_message("ENGLAND", "FRANCE", "Hello France!")
|
||||
await client.send_message("ENGLAND", "FRANCE", "Another message from England")
|
||||
await client.send_message("GERMANY", "FRANCE", "Message from Germany")
|
||||
await client.send_message("ITALY", "GLOBAL", "Global message from Italy")
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
await client.synchronize()
|
||||
|
||||
# Now analyze targeting
|
||||
targets = await analyze_recent_messages_for_targeting(client=client, power_name="FRANCE", max_messages=20)
|
||||
|
||||
# ENGLAND should be prioritized (sent 2 direct messages to FRANCE)
|
||||
# GERMANY should be second (sent 1 direct message to FRANCE)
|
||||
# Others should follow
|
||||
assert isinstance(targets, list)
|
||||
assert "FRANCE" not in targets
|
||||
|
||||
# England should be first in priority (most direct messages to FRANCE)
|
||||
if "ENGLAND" in targets:
|
||||
england_index = targets.index("ENGLAND")
|
||||
if "GERMANY" in targets:
|
||||
germany_index = targets.index("GERMANY")
|
||||
assert england_index < germany_index, "England should be prioritized over Germany"
|
||||
|
||||
async def test_analyze_messages_error_handling(self, client_with_messages):
|
||||
"""Test error handling in message analysis."""
|
||||
client = client_with_messages
|
||||
|
||||
# Test with invalid power name
|
||||
targets = await analyze_recent_messages_for_targeting(client=client, power_name="INVALID_POWER", max_messages=20)
|
||||
|
||||
# Should still return a valid list (fallback behavior)
|
||||
assert isinstance(targets, list)
|
||||
|
||||
|
||||
class TestNegotiationTiming:
|
||||
"""Test negotiation delay and timing logic."""
|
||||
|
||||
def test_negotiation_delay_calculation(self):
|
||||
"""Test that negotiation delays are calculated correctly."""
|
||||
|
||||
# Test first round gets extra time
|
||||
first_round_delay = get_negotiation_delay(round_number=1, total_rounds=3)
|
||||
base_delay = get_negotiation_delay(round_number=2, total_rounds=3)
|
||||
final_round_delay = get_negotiation_delay(round_number=3, total_rounds=3)
|
||||
|
||||
assert first_round_delay > base_delay, "First round should have longer delay"
|
||||
assert final_round_delay < base_delay, "Final round should have shorter delay"
|
||||
|
||||
# All delays should be positive
|
||||
assert first_round_delay > 0
|
||||
assert base_delay > 0
|
||||
assert final_round_delay > 0
|
||||
|
||||
def test_negotiation_delay_edge_cases(self):
|
||||
"""Test delay calculation with edge cases."""
|
||||
|
||||
# Single round
|
||||
single_delay = get_negotiation_delay(round_number=1, total_rounds=1)
|
||||
assert single_delay > 0
|
||||
|
||||
# Many rounds
|
||||
many_rounds_delay = get_negotiation_delay(round_number=5, total_rounds=10)
|
||||
assert many_rounds_delay > 0
|
||||
|
||||
|
||||
class TestNegotiationParticipation:
|
||||
"""Test logic for determining negotiation participation."""
|
||||
|
||||
@pytest.fixture
|
||||
async def client_and_mocked_agent(self, fake_server):
|
||||
"""Client with mocked agent for participation testing."""
|
||||
client = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
|
||||
|
||||
try:
|
||||
await client.connect_and_authenticate("participation_user", "password")
|
||||
await client.create_game(
|
||||
map_name="standard",
|
||||
rules=["IGNORE_ERRORS"],
|
||||
power_name="RUSSIA",
|
||||
n_controls=1,
|
||||
)
|
||||
await client.synchronize()
|
||||
|
||||
# Create mock agent
|
||||
mock_agent = MagicMock(spec=DiplomacyAgent)
|
||||
mock_agent.power_name = "RUSSIA"
|
||||
|
||||
yield {"client": client, "agent": mock_agent}
|
||||
|
||||
finally:
|
||||
try:
|
||||
await client.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
async def test_should_participate_eliminated_power(self, client_and_mocked_agent):
|
||||
"""Test that eliminated powers don't participate in negotiations."""
|
||||
client = client_and_mocked_agent["client"]
|
||||
agent = client_and_mocked_agent["agent"]
|
||||
|
||||
# Mock the power as eliminated
|
||||
with patch.object(client, "get_power") as mock_get_power:
|
||||
mock_power = MagicMock()
|
||||
mock_power.is_eliminated.return_value = True
|
||||
mock_get_power.return_value = mock_power
|
||||
|
||||
should_participate = await should_participate_in_negotiations(client, agent)
|
||||
assert should_participate is False
|
||||
|
||||
async def test_should_participate_no_orders(self, client_and_mocked_agent):
|
||||
"""Test that powers with no orderable locations don't negotiate."""
|
||||
client = client_and_mocked_agent["client"]
|
||||
agent = client_and_mocked_agent["agent"]
|
||||
|
||||
# Mock the power as not eliminated but with no orderable locations
|
||||
with patch.object(client, "get_power") as mock_get_power, patch("websocket_negotiations.gather_possible_orders") as mock_orders:
|
||||
mock_power = MagicMock()
|
||||
mock_power.is_eliminated.return_value = False
|
||||
mock_get_power.return_value = mock_power
|
||||
mock_orders.return_value = [] # No possible orders
|
||||
|
||||
should_participate = await should_participate_in_negotiations(client, agent)
|
||||
assert should_participate is False
|
||||
|
||||
async def test_should_participate_non_movement_phase(self, client_and_mocked_agent):
|
||||
"""Test that powers don't negotiate in non-movement phases."""
|
||||
client = client_and_mocked_agent["client"]
|
||||
agent = client_and_mocked_agent["agent"]
|
||||
|
||||
# Mock retreat phase
|
||||
with (
|
||||
patch.object(client, "get_power") as mock_get_power,
|
||||
patch.object(client, "get_current_short_phase") as mock_phase,
|
||||
patch("websocket_negotiations.gather_possible_orders") as mock_orders,
|
||||
):
|
||||
mock_power = MagicMock()
|
||||
mock_power.is_eliminated.return_value = False
|
||||
mock_get_power.return_value = mock_power
|
||||
mock_orders.return_value = ["A Moscow - Hold"] # Has orders
|
||||
mock_phase.return_value = "S1901R" # Retreat phase
|
||||
|
||||
should_participate = await should_participate_in_negotiations(client, agent)
|
||||
assert should_participate is False
|
||||
|
||||
async def test_should_participate_movement_phase_with_orders(self, client_and_mocked_agent):
|
||||
"""Test that active powers with orders participate in movement phases."""
|
||||
client = client_and_mocked_agent["client"]
|
||||
agent = client_and_mocked_agent["agent"]
|
||||
|
||||
# Mock movement phase with orderable locations
|
||||
with (
|
||||
patch.object(client, "get_power") as mock_get_power,
|
||||
patch.object(client, "get_current_short_phase") as mock_phase,
|
||||
patch("websocket_negotiations.gather_possible_orders") as mock_orders,
|
||||
):
|
||||
mock_power = MagicMock()
|
||||
mock_power.is_eliminated.return_value = False
|
||||
mock_get_power.return_value = mock_power
|
||||
mock_orders.return_value = ["A Moscow - Hold", "F Sevastopol - Black Sea"]
|
||||
mock_phase.return_value = "S1901M" # Movement phase
|
||||
|
||||
should_participate = await should_participate_in_negotiations(client, agent)
|
||||
assert should_participate is True
|
||||
|
||||
|
||||
class TestStrategicNegotiationRound:
|
||||
"""Test the complete strategic negotiation round functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
async def negotiation_setup(self, fake_server):
|
||||
"""Setup for testing complete negotiation rounds."""
|
||||
client = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
|
||||
|
||||
try:
|
||||
await client.connect_and_authenticate("negotiation_user", "password")
|
||||
await client.create_game(
|
||||
map_name="standard",
|
||||
rules=["IGNORE_ERRORS", "POWER_CHOICE"],
|
||||
power_name="TURKEY",
|
||||
n_controls=1,
|
||||
)
|
||||
await client.synchronize()
|
||||
|
||||
# Create mock agent with necessary attributes
|
||||
mock_agent = MagicMock(spec=DiplomacyAgent)
|
||||
mock_agent.power_name = "TURKEY"
|
||||
mock_agent.goals = "Expand into the Mediterranean"
|
||||
mock_agent.relationships = {}
|
||||
mock_agent.format_private_diary_for_prompt.return_value = "Test diary entry"
|
||||
|
||||
# Mock AI client for message generation
|
||||
mock_ai_client = AsyncMock()
|
||||
mock_agent.client = mock_ai_client
|
||||
|
||||
# Create game history
|
||||
game_history = GameHistory()
|
||||
|
||||
# Mock error stats
|
||||
error_stats = {"test_model": {"conversation_errors": 0, "order_decoding_errors": 0}}
|
||||
|
||||
yield {"client": client, "agent": mock_agent, "game_history": game_history, "error_stats": error_stats}
|
||||
|
||||
finally:
|
||||
try:
|
||||
await client.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
async def test_successful_negotiation_round(self, negotiation_setup):
|
||||
"""Test a successful negotiation round with message generation."""
|
||||
setup = negotiation_setup
|
||||
client = setup["client"]
|
||||
agent = setup["agent"]
|
||||
game_history = setup["game_history"]
|
||||
error_stats = setup["error_stats"]
|
||||
|
||||
# Mock AI response for message generation
|
||||
mock_messages = [
|
||||
{"content": "Greetings! Turkey seeks peaceful relations.", "message_type": "global"},
|
||||
{"content": "Russia, shall we coordinate our efforts?", "message_type": "private", "recipient": "RUSSIA"},
|
||||
]
|
||||
agent.client.get_conversation_reply.return_value = mock_messages
|
||||
|
||||
# Mock power and orders
|
||||
with (
|
||||
patch.object(client, "get_power") as mock_get_power,
|
||||
patch("websocket_negotiations.gather_possible_orders") as mock_orders,
|
||||
patch.object(client, "send_message") as mock_send,
|
||||
):
|
||||
mock_power = MagicMock()
|
||||
mock_power.is_eliminated.return_value = False
|
||||
mock_get_power.return_value = mock_power
|
||||
mock_orders.return_value = ["A Constantinople - Hold"]
|
||||
|
||||
# Run negotiation round
|
||||
success = await conduct_strategic_negotiation_round(
|
||||
client=client,
|
||||
agent=agent,
|
||||
game_history=game_history,
|
||||
model_error_stats=error_stats,
|
||||
log_file_path="/tmp/test_log.txt",
|
||||
round_number=1,
|
||||
max_rounds=3,
|
||||
)
|
||||
|
||||
assert success is True
|
||||
# Should have called send_message for each generated message
|
||||
assert mock_send.call_count == len(mock_messages)
|
||||
|
||||
async def test_negotiation_round_no_messages(self, negotiation_setup):
|
||||
"""Test negotiation round when AI generates no messages."""
|
||||
setup = negotiation_setup
|
||||
client = setup["client"]
|
||||
agent = setup["agent"]
|
||||
game_history = setup["game_history"]
|
||||
error_stats = setup["error_stats"]
|
||||
|
||||
# Mock AI response with no messages
|
||||
agent.client.get_conversation_reply.return_value = []
|
||||
|
||||
# Mock power and orders
|
||||
with patch.object(client, "get_power") as mock_get_power, patch("websocket_negotiations.gather_possible_orders") as mock_orders:
|
||||
mock_power = MagicMock()
|
||||
mock_power.is_eliminated.return_value = False
|
||||
mock_get_power.return_value = mock_power
|
||||
mock_orders.return_value = ["A Constantinople - Hold"]
|
||||
|
||||
# Run negotiation round
|
||||
success = await conduct_strategic_negotiation_round(
|
||||
client=client,
|
||||
agent=agent,
|
||||
game_history=game_history,
|
||||
model_error_stats=error_stats,
|
||||
log_file_path="/tmp/test_log.txt",
|
||||
round_number=1,
|
||||
max_rounds=3,
|
||||
)
|
||||
|
||||
assert success is False
|
||||
|
||||
async def test_negotiation_round_ai_error(self, negotiation_setup):
|
||||
"""Test negotiation round when AI client raises an exception."""
|
||||
setup = negotiation_setup
|
||||
client = setup["client"]
|
||||
agent = setup["agent"]
|
||||
game_history = setup["game_history"]
|
||||
error_stats = setup["error_stats"]
|
||||
|
||||
# Mock AI client to raise exception
|
||||
agent.client.get_conversation_reply.side_effect = Exception("AI service unavailable")
|
||||
agent.client.model_name = "test_model"
|
||||
|
||||
# Mock power and orders
|
||||
with patch.object(client, "get_power") as mock_get_power, patch("websocket_negotiations.gather_possible_orders") as mock_orders:
|
||||
mock_power = MagicMock()
|
||||
mock_power.is_eliminated.return_value = False
|
||||
mock_get_power.return_value = mock_power
|
||||
mock_orders.return_value = ["A Constantinople - Hold"]
|
||||
|
||||
# Run negotiation round
|
||||
success = await conduct_strategic_negotiation_round(
|
||||
client=client,
|
||||
agent=agent,
|
||||
game_history=game_history,
|
||||
model_error_stats=error_stats,
|
||||
log_file_path="/tmp/test_log.txt",
|
||||
round_number=1,
|
||||
max_rounds=3,
|
||||
)
|
||||
|
||||
assert success is False
|
||||
# Error should be tracked in statistics
|
||||
assert error_stats["test_model"]["conversation_errors"] == 1
|
||||
|
||||
async def test_negotiation_round_with_targeting(self, negotiation_setup):
|
||||
"""Test that negotiation round uses strategic targeting."""
|
||||
setup = negotiation_setup
|
||||
client = setup["client"]
|
||||
agent = setup["agent"]
|
||||
game_history = setup["game_history"]
|
||||
error_stats = setup["error_stats"]
|
||||
|
||||
# Mock message targeting analysis
|
||||
with patch("websocket_negotiations.analyze_recent_messages_for_targeting") as mock_targeting:
|
||||
mock_targeting.return_value = ["RUSSIA", "AUSTRIA", "ITALY"]
|
||||
|
||||
# Mock AI response
|
||||
mock_messages = [{"content": "Test message", "message_type": "global"}]
|
||||
agent.client.get_conversation_reply.return_value = mock_messages
|
||||
|
||||
# Mock power and orders
|
||||
with patch.object(client, "get_power") as mock_get_power, patch("websocket_negotiations.gather_possible_orders") as mock_orders:
|
||||
mock_power = MagicMock()
|
||||
mock_power.is_eliminated.return_value = False
|
||||
mock_get_power.return_value = mock_power
|
||||
mock_orders.return_value = ["A Constantinople - Hold"]
|
||||
|
||||
# Run negotiation round 2 (should use targeting)
|
||||
await conduct_strategic_negotiation_round(
|
||||
client=client,
|
||||
agent=agent,
|
||||
game_history=game_history,
|
||||
model_error_stats=error_stats,
|
||||
log_file_path="/tmp/test_log.txt",
|
||||
round_number=2, # Round > 1 should use targeting
|
||||
max_rounds=3,
|
||||
)
|
||||
|
||||
# Should have called the targeting analysis
|
||||
mock_targeting.assert_called_once_with(client, "TURKEY")
|
||||
|
|
@ -19,16 +19,14 @@ from models import (
|
|||
serialize_message,
|
||||
parse_message,
|
||||
)
|
||||
from typed_websocket_client import TypedWebSocketDiplomacyClient
|
||||
from websocket_diplomacy_client import WebSocketDiplomacyClient
|
||||
|
||||
|
||||
class TestAuthentication:
|
||||
"""Test authentication flow and message handling."""
|
||||
|
||||
@pytest.mark.websocket
|
||||
async def test_successful_authentication(
|
||||
self, client: TypedWebSocketDiplomacyClient, helpers
|
||||
):
|
||||
async def test_successful_authentication(self, client: WebSocketDiplomacyClient, helpers):
|
||||
"""Test successful authentication with valid credentials."""
|
||||
await client.connect_and_authenticate("test_user", "test_password")
|
||||
helpers.assert_valid_token(client.token)
|
||||
|
|
@ -44,16 +42,12 @@ class TestAuthentication:
|
|||
assert client.token is None
|
||||
|
||||
@pytest.mark.websocket
|
||||
async def test_authentication_message_format(
|
||||
self, credentials: Dict[str, str], client: TypedWebSocketDiplomacyClient
|
||||
):
|
||||
async def test_authentication_message_format(self, credentials: Dict[str, str], client: WebSocketDiplomacyClient):
|
||||
"""Test that authentication messages are properly formatted."""
|
||||
await client.connect_and_authenticate(**credentials)
|
||||
|
||||
# Create a sign-in request manually to test message format
|
||||
request = SignInRequest(
|
||||
request_id=str(uuid.uuid4()), username="test_user", password="test_password"
|
||||
)
|
||||
request = SignInRequest(request_id=str(uuid.uuid4()), username="test_user", password="test_password")
|
||||
|
||||
# Verify the request serializes correctly
|
||||
message_data = serialize_message(request)
|
||||
|
|
@ -74,9 +68,7 @@ class TestGameOperations:
|
|||
@pytest.mark.websocket
|
||||
async def test_create_game(self, authenticated_client, helpers):
|
||||
"""Test creating a new game."""
|
||||
game_data = await authenticated_client.create_game(
|
||||
power_name="FRANCE", n_controls=1
|
||||
)
|
||||
game_data = await authenticated_client.create_game(power_name="FRANCE", n_controls=1)
|
||||
|
||||
helpers.assert_valid_game_data(game_data)
|
||||
assert authenticated_client.game_id is not None
|
||||
|
|
@ -98,9 +90,7 @@ class TestGameOperations:
|
|||
assert authenticated_client.game_id in game_ids
|
||||
|
||||
@pytest.mark.websocket
|
||||
async def test_join_existing_game(
|
||||
self, client: TypedWebSocketDiplomacyClient, helpers
|
||||
):
|
||||
async def test_join_existing_game(self, client: WebSocketDiplomacyClient, helpers):
|
||||
"""Test joining an existing game."""
|
||||
# Connect and authenticate a first client
|
||||
await client.connect()
|
||||
|
|
@ -130,9 +120,7 @@ class TestGameOperations:
|
|||
async def test_join_nonexistent_game(self, authenticated_client):
|
||||
"""Test joining a game that doesn't exist."""
|
||||
with pytest.raises(ValueError, match="Game .* not found"):
|
||||
await authenticated_client.join_game(
|
||||
"NONEXISTENT_GAME", power_name="FRANCE"
|
||||
)
|
||||
await authenticated_client.join_game("NONEXISTENT_GAME", power_name="FRANCE")
|
||||
|
||||
|
||||
class TestGamePlay:
|
||||
|
|
@ -246,9 +234,7 @@ class TestMessageValidation:
|
|||
"""Test that messages can be serialized and parsed correctly."""
|
||||
# Test various message types
|
||||
messages = [
|
||||
SignInRequest(
|
||||
request_id=str(uuid.uuid4()), username="test", password="pass"
|
||||
),
|
||||
SignInRequest(request_id=str(uuid.uuid4()), username="test", password="pass"),
|
||||
CreateGameRequest(
|
||||
request_id=str(uuid.uuid4()),
|
||||
token="test_token",
|
||||
|
|
@ -289,7 +275,7 @@ class TestConcurrentOperations:
|
|||
try:
|
||||
# Create multiple clients
|
||||
for i in range(3):
|
||||
client = TypedWebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
|
||||
client = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
|
||||
await client.connect()
|
||||
await client.authenticate("test_user", "test_password")
|
||||
clients.append(client)
|
||||
|
|
@ -300,9 +286,7 @@ class TestConcurrentOperations:
|
|||
|
||||
# Each can create games independently
|
||||
for i, client in enumerate(clients):
|
||||
game_data = await client.create_game(
|
||||
power_name="FRANCE" if i == 0 else None, n_controls=1
|
||||
)
|
||||
game_data = await client.create_game(power_name="FRANCE" if i == 0 else None, n_controls=1)
|
||||
assert game_data["game_id"] is not None
|
||||
|
||||
finally:
|
||||
|
|
|
|||
|
|
@ -36,3 +36,4 @@ dependencies = [
|
|||
[tool.ruff]
|
||||
line-length = 150
|
||||
exclude = ["diplomacy", "docs"]
|
||||
line-length = 150
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue