WIP: Continual work on the bot client. They successfully play the game, but don't negotiate

Signed-off-by: Tyler Marques <me@tylermarques.com>
This commit is contained in:
Tyler Marques 2025-07-01 10:36:54 -07:00
parent c09d5bfe7e
commit 79a31d615d
No known key found for this signature in database
GPG key ID: CB99EDCF41D3016F
13 changed files with 2483 additions and 263 deletions

View file

@ -6,7 +6,7 @@ import warnings
class Configuration(BaseSettings):
DEBUG: bool = False
log_file_path: Path
log_file_path: Path | None = None
DEEPSEEK_API_KEY: str | None = None
OPENAI_API_KEY: str | None = None
ANTHROPIC_API_KEY: str | None = None

View file

@ -72,7 +72,7 @@ class MultiBotLauncher:
logger.info("Creating new game...")
# Connect as the game creator
creator_username = f"{self.base_username}_{creator_power.lower()}"
creator_username = f"{self.base_username}_{creator_power}"
client = await connect_to_diplomacy_server(
hostname=self.hostname,
port=self.port,
@ -105,6 +105,9 @@ class MultiBotLauncher:
game_id: str,
log_level: str = "INFO",
negotiation_rounds: int = 3,
connection_timeout: float = 30.0,
max_retries: int = 3,
retry_delay: float = 2.0,
) -> subprocess.Popen:
"""
Launch a single bot process.
@ -141,6 +144,12 @@ class MultiBotLauncher:
log_level,
"--negotiation-rounds",
str(negotiation_rounds),
"--connection-timeout",
str(connection_timeout),
"--max-retries",
str(max_retries),
"--retry-delay",
str(retry_delay),
]
logger.info(f"Launching bot for {power} with model {model}")
@ -165,6 +174,9 @@ class MultiBotLauncher:
log_level: str = "INFO",
stagger_delay: float = 0.5,
negotiation_rounds: int = 3,
connection_timeout: float = 30.0,
max_retries: int = 3,
retry_delay: float = 2.0,
):
"""
Launch bots for all specified powers.
@ -188,7 +200,7 @@ class MultiBotLauncher:
model = models.get(power, "gpt-3.5-turbo")
try:
process = self.launch_bot(power, model, game_id, log_level, negotiation_rounds)
process = self.launch_bot(power, model, game_id, log_level, negotiation_rounds, connection_timeout, max_retries, retry_delay)
self.bot_processes.append(process)
self.process_to_power[process] = power
@ -266,7 +278,7 @@ class MultiBotLauncher:
break
print(f"{power}_{process.pid}: {line.strip()}")
lines_read += 1
except:
except Exception:
break
except (OSError, ValueError):
@ -299,7 +311,7 @@ class MultiBotLauncher:
break
print(f"{power}_{process.pid}: {line.strip()}")
lines_read += 1
except:
except Exception:
break
def stop_all_bots(self):
@ -329,6 +341,9 @@ class MultiBotLauncher:
log_level: str = "INFO",
creator_power: str = "FRANCE",
negotiation_rounds: int = 3,
connection_timeout: float = 30.0,
max_retries: int = 3,
retry_delay: float = 2.0,
):
"""
Create a game and launch all bots for a complete game.
@ -352,6 +367,9 @@ class MultiBotLauncher:
models,
log_level=log_level,
negotiation_rounds=negotiation_rounds,
connection_timeout=connection_timeout,
max_retries=max_retries,
retry_delay=retry_delay,
)
# Monitor the bots
@ -369,6 +387,9 @@ class MultiBotLauncher:
models: Optional[Dict[str, str]] = None,
log_level: str = "INFO",
negotiation_rounds: int = 3,
connection_timeout: float = 30.0,
max_retries: int = 3,
retry_delay: float = 2.0,
):
"""
Launch bots to join an existing game.
@ -389,6 +410,9 @@ class MultiBotLauncher:
powers,
log_level,
negotiation_rounds=negotiation_rounds,
connection_timeout=connection_timeout,
max_retries=max_retries,
retry_delay=retry_delay,
)
# Monitor the bots
@ -419,12 +443,33 @@ def parse_arguments():
default=3,
help="Number of negotiation rounds per movement phase (default: 3)",
)
parser.add_argument(
"--connection-timeout",
type=float,
default=30.0,
help="Timeout for network operations in seconds (default: 30.0)",
)
parser.add_argument(
"--max-retries",
type=int,
default=3,
help="Maximum number of retries for failed operations (default: 3)",
)
parser.add_argument(
"--retry-delay",
type=float,
default=2.0,
help="Base delay between retries in seconds (default: 2.0)",
)
return parser.parse_args()
async def main():
"""Main entry point."""
# FIXME: Arg parse appears to not like game ids with hypens in the name. e.g.
# uv run python multi_bot_launcher.py --game-id "-1D0i-fobmvprIh1" results in an error
args = parse_arguments()
launcher = MultiBotLauncher(
@ -454,6 +499,9 @@ async def main():
models=models,
log_level=args.log_level,
negotiation_rounds=args.negotiation_rounds,
connection_timeout=args.connection_timeout,
max_retries=args.max_retries,
retry_delay=args.retry_delay,
)
else:
# Create new game and launch all bots
@ -462,6 +510,9 @@ async def main():
log_level=args.log_level,
creator_power=args.creator_power,
negotiation_rounds=args.negotiation_rounds,
connection_timeout=args.connection_timeout,
max_retries=args.max_retries,
retry_delay=args.retry_delay,
)
except KeyboardInterrupt:

View file

@ -13,6 +13,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
import argparse
import asyncio
import signal
import time
from typing import Optional, Dict, List
import dotenv
from loguru import logger
@ -21,8 +22,7 @@ from loguru import logger
from websocket_diplomacy_client import WebSocketDiplomacyClient, connect_to_diplomacy_server
from diplomacy.engine.power import Power
from diplomacy.utils.exceptions import GameIdException
from diplomacy.utils.exceptions import GameIdException, DiplomacyException
from diplomacy.communication.notifications import GameStatusUpdate
from diplomacy.engine.message import Message
@ -41,7 +41,6 @@ from websocket_negotiations import (
dotenv.load_dotenv()
# TODO: This, but better
config: Configuration = None
class SingleBotPlayer:
@ -64,6 +63,9 @@ class SingleBotPlayer:
port: int = 8432,
game_id: Optional[str] = None,
negotiation_rounds: int = 3,
connection_timeout: float = 30.0,
max_retries: int = 3,
retry_delay: float = 2.0,
):
assert username is not None
assert password is not None
@ -77,6 +79,7 @@ class SingleBotPlayer:
self.power_name = power_name
self.model_name = model_name
self.game_id = game_id
self.config = Configuration(power_name=power_name)
# Bot state
self.client: WebSocketDiplomacyClient
@ -100,6 +103,21 @@ class SingleBotPlayer:
self.response_counts: Dict[str, int] = {} # Responses sent to each power
self.priority_contacts: List[str] = [] # Powers to prioritize for communication
# Connection health and fault tolerance (configurable)
self.connection_timeout = connection_timeout
self.retry_delay = retry_delay
self.max_retries = max_retries
self.last_successful_operation = time.time()
self.connection_failures = 0
self.circuit_breaker_open = False
self.circuit_breaker_last_failure = 0
self.circuit_breaker_timeout = 60.0 # 1 minute before trying again
# Add graceful shutdown flag
self.shutdown_requested = False
logger.info(f"Fault tolerance config: timeout={connection_timeout}s, max_retries={max_retries}, retry_delay={retry_delay}s")
# Setup signal handlers for graceful shutdown
signal.signal(signal.SIGINT, self._signal_handler)
signal.signal(signal.SIGTERM, self._signal_handler)
@ -109,12 +127,80 @@ class SingleBotPlayer:
logger.info(f"Received signal {signum}, shutting down...")
self.running = False
def _is_circuit_breaker_open(self) -> bool:
"""Check if circuit breaker is open (preventing operations due to failures)."""
if not self.circuit_breaker_open:
return False
# Check if timeout has passed and we should try again
if time.time() - self.circuit_breaker_last_failure > self.circuit_breaker_timeout:
logger.info("Circuit breaker timeout expired, allowing operations")
self.circuit_breaker_open = False
return False
return True
def _record_operation_success(self):
"""Record a successful operation."""
self.last_successful_operation = time.time()
self.connection_failures = 0
if self.circuit_breaker_open:
logger.info("Operation successful, closing circuit breaker")
self.circuit_breaker_open = False
def _record_operation_failure(self):
"""Record a failed operation and potentially open circuit breaker."""
self.connection_failures += 1
logger.warning(f"Operation failed, failure count: {self.connection_failures}")
if self.connection_failures >= 5: # Open circuit after 5 consecutive failures
logger.error("Opening circuit breaker due to repeated failures")
self.circuit_breaker_open = True
self.circuit_breaker_last_failure = time.time()
async def _retry_with_backoff(self, operation, *args, **kwargs):
"""Execute an operation with exponential backoff retry logic."""
if self._is_circuit_breaker_open():
raise DiplomacyException("Circuit breaker is open, operation not allowed")
last_exception = None
for attempt in range(self.max_retries + 1):
try:
result = await asyncio.wait_for(operation(*args, **kwargs), timeout=self.connection_timeout)
self._record_operation_success()
return result
except (TimeoutError, asyncio.TimeoutError) as e:
last_exception = e
logger.warning(f"Operation timeout on attempt {attempt + 1}/{self.max_retries + 1}: {e}")
except (ConnectionError, DiplomacyException) as e:
last_exception = e
logger.warning(f"Connection error on attempt {attempt + 1}/{self.max_retries + 1}: {e}")
except Exception as e:
last_exception = e
logger.error(f"Unexpected error on attempt {attempt + 1}/{self.max_retries + 1}: {e}")
# Don't delay after the last attempt
if attempt < self.max_retries:
delay = self.retry_delay * (2**attempt) # Exponential backoff
logger.info(f"Retrying in {delay:.1f} seconds...")
await asyncio.sleep(delay)
# All retries failed
self._record_operation_failure()
logger.error(f"Operation failed after {self.max_retries + 1} attempts")
raise last_exception or Exception("Operation failed with unknown error")
async def connect_and_initialize(self):
"""Connect to the server and initialize the bot."""
logger.info(f"Connecting to {self.hostname}:{self.port} as {self.username}")
# Connect to server
self.client = await connect_to_diplomacy_server(
# Connect to server with retry logic
self.client = await self._retry_with_backoff(
connect_to_diplomacy_server,
hostname=self.hostname,
port=self.port,
username=self.username,
@ -142,7 +228,7 @@ class SingleBotPlayer:
self.agent = DiplomacyAgent(power_name=self.power_name, client=model_client)
# Initialize agent state
await initialize_agent_state_ext(self.agent, self.client.game, self.game_history, config.log_file_path)
await initialize_agent_state_ext(self.agent, self.client.game, self.game_history, self.config.log_file_path)
# Setup game event callbacks
await self._setup_event_callbacks()
@ -187,8 +273,14 @@ class SingleBotPlayer:
async def _handle_phase_update_async(self, notification):
"""Async handler for phase updates."""
# Update our game state
await self.client.game.synchronize()
try:
# Update our game state with retry logic
await self._retry_with_backoff(self.client.game.synchronize)
except Exception as e:
# This is a critical error. If we cannot synchronize the game, even with backoffs, we shouldn't continue.
logger.critical(f"Failed to synchronize game state during phase update: {e}")
raise e
new_phase = self.client.game.get_current_phase()
if new_phase != self.current_phase:
@ -219,14 +311,20 @@ class SingleBotPlayer:
async def _handle_game_processed_async(self):
"""Async handler for game processing."""
# Synchronize to get the results
await self.client.game.synchronize()
try:
# Synchronize to get the results with retry logic
await self._retry_with_backoff(self.client.game.synchronize)
# Analyze the results
await self._analyze_phase_results()
# Analyze the results
await self._analyze_phase_results()
self.orders_submitted = False
self.waiting_for_orders = False
self.orders_submitted = False
self.waiting_for_orders = False
except Exception as e:
logger.error(f"Failed to handle game processing: {e}")
# Reset state even if synchronization failed
self.orders_submitted = False
self.waiting_for_orders = False
def _on_message_received(self, game, notification):
"""Handle incoming diplomatic messages."""
@ -289,7 +387,7 @@ class SingleBotPlayer:
return
if action_type == "build":
# Get buildable locations for this power
possible_orders = gather_possible_order(self.client.game, self.power_name)
possible_orders = gather_possible_orders(self.client.game, self.power_name)
if not possible_orders:
logger.warning(f"No possible build orders for {self.power_name}")
await self.client.set_orders(self.power_name, [])
@ -309,6 +407,16 @@ class SingleBotPlayer:
elif action_type == "disband":
# Get current units for disbanding
power = self.client.game.get_power(self.power_name)
current_units = list(power.units.keys())
# Create disband orders
disband_orders = []
for i, unit_location in enumerate(current_units[:count]):
# Format: "A Berlin - DISBAND" or "F London - DISBAND"
unit_type = power.units[unit_location][0] # 'A' or 'F'
disband_order = f"{unit_type} {unit_location} - DISBAND"
disband_orders.append(disband_order)
logger.info(f"Submitting disband orders for {self.power_name}: {disband_orders}")
await self.client.set_orders(self.power_name, disband_orders)
@ -323,60 +431,74 @@ class SingleBotPlayer:
logger.debug("Orders already submitted for this phase")
return
current_phase = self.client.game.get_current_phase()
logger.info(f"Generating orders for {self.power_name} in phase {current_phase}...")
try:
current_phase = self.client.game.get_current_phase()
logger.info(f"Generating orders for {self.power_name} in phase {current_phase}...")
# Get current board state
board_state = self.client.game.get_state()
# Get current board state
board_state = self.client.game.get_state()
# Get possible orders
possible_orders = gather_possible_orders(self.client.game, self.power_name)
# Get possible orders
possible_orders = gather_possible_orders(self.client.game, self.power_name)
logger.debug(f"Possible orders for {self.power_name}: {possible_orders}")
logger.debug(f"Possible orders for {self.power_name}: {possible_orders}")
if not possible_orders:
logger.info(f"No possible orders for {self.power_name}, submitting empty order set")
await self.client.set_orders(self.power_name, [])
self.orders_submitted = True
return
if not possible_orders:
logger.info(f"No possible orders for {self.power_name}, submitting empty order set")
await self._retry_with_backoff(self.client.set_orders, self.power_name, [])
self.orders_submitted = True
return
# Generate orders using AI
orders = await get_valid_orders(
game=self.client.game,
client=self.agent.client,
board_state=board_state,
power_name=self.power_name,
possible_orders=possible_orders,
game_history=self.game_history,
model_error_stats=self.error_stats,
agent_goals=self.agent.goals,
agent_relationships=self.agent.relationships,
agent_private_diary_str=self.agent.format_private_diary_for_prompt(),
phase=self.client.game.get_current_phase(),
)
# Submit orders
if orders:
logger.info(f"Submitting orders: {orders}")
await self.client.set_orders(self.power_name, orders)
# Generate order diary entry
await self.agent.generate_order_diary_entry(
self.client.game,
orders,
config.log_file_path,
# Generate orders using AI
orders = await get_valid_orders(
game=self.client.game,
client=self.agent.client,
board_state=board_state,
power_name=self.power_name,
possible_orders=possible_orders,
game_history=self.game_history,
model_error_stats=self.error_stats,
agent_goals=self.agent.goals,
agent_relationships=self.agent.relationships,
agent_private_diary_str=self.agent.format_private_diary_for_prompt(),
phase=self.client.game.get_current_phase(),
)
else:
logger.info("No valid orders generated, submitting empty order set")
await self.client.set_orders(self.power_name, [])
self.orders_submitted = True
self.waiting_for_orders = False
logger.info("Orders submitted successfully")
# Call the no wait so we don't sit around for the turns to end.
# TODO: We probably don't want to call this here.
# We want to call it when negotiations end,
self.client.game.no_wait()
# Submit orders with retry logic
if orders:
logger.info(f"Submitting orders: {orders}")
await self._retry_with_backoff(self.client.set_orders, self.power_name, orders)
# Generate order diary entry (don't retry this if it fails)
try:
await self.agent.generate_order_diary_entry(
self.client.game,
orders,
self.config.log_file_path,
)
except Exception as diary_error:
logger.warning(f"Failed to generate order diary entry: {diary_error}")
else:
logger.info("No valid orders generated, submitting empty order set")
await self._retry_with_backoff(self.client.set_orders, self.power_name, [])
self.orders_submitted = True
self.waiting_for_orders = False
logger.info("Orders submitted successfully")
# Call the no wait so we don't sit around for the turns to end.
# TODO: We probably don't want to call this here.
# We want to call it when negotiations end,
try:
self.client.game.no_wait()
except Exception as no_wait_error:
logger.warning(f"Failed to call no_wait: {no_wait_error}")
except Exception as e:
logger.error(f"Failed to submit orders: {e}", exc_info=True)
# Mark as submitted to avoid infinite retry loops
self.orders_submitted = True
self.waiting_for_orders = False
async def _analyze_phase_results(self):
"""Analyze the results of the previous phase."""
@ -394,7 +516,7 @@ class SingleBotPlayer:
board_state=board_state,
phase_summary=phase_summary,
game_history=self.game_history,
log_file_path=config.log_file_path,
log_file_path=self.config.log_file_path,
)
logger.info("Phase analysis complete")
@ -422,7 +544,7 @@ class SingleBotPlayer:
agent=self.agent,
game_history=self.game_history,
model_error_stats=self.error_stats,
log_file_path=config.log_file_path,
log_file_path=self.config.log_file_path,
round_number=round_num,
max_rounds=self.negotiation_rounds,
)
@ -505,7 +627,7 @@ class SingleBotPlayer:
possible_orders=possible_orders,
game_history=self.game_history,
game_phase=self.client.get_current_short_phase(),
log_file_path=config.log_file_path,
log_file_path=self.config.log_file_path,
active_powers=active_powers,
agent_goals=self.agent.goals,
agent_relationships=self.agent.relationships,
@ -516,12 +638,17 @@ class SingleBotPlayer:
if responses and len(responses) > 0:
response_content = responses[0].get("content", "").strip()
if response_content:
await self.client.send_message(
sender=self.power_name,
recipient=message.sender,
message=response_content,
phase=self.client.get_current_short_phase(),
)
try:
await self._retry_with_backoff(
self.client.send_message,
sender=self.power_name,
recipient=message.sender,
message=response_content,
phase=self.client.get_current_short_phase(),
)
except Exception as send_error:
logger.warning(f"Failed to send message response: {send_error}")
return # Don't record the message if sending failed
# Add to game history
self.game_history.add_message(
@ -607,14 +734,31 @@ class SingleBotPlayer:
# Main event loop
while self.running and not self.client.game.is_game_done:
# Synchronize with server periodically
await self.client.game.synchronize()
try:
# Synchronize with server periodically with retry logic
await self._retry_with_backoff(self.client.game.synchronize)
# Check if we need to submit orders
await self._check_if_orders_needed()
# Check if we need to submit orders
await self._check_if_orders_needed()
# Sleep for a bit before next iteration
await asyncio.sleep(5)
# Sleep for a bit before next iteration
await asyncio.sleep(5)
except (asyncio.CancelledError, KeyboardInterrupt):
logger.info("Bot operation cancelled, shutting down")
break
except (TimeoutError, asyncio.TimeoutError) as e:
logger.warning(f"Timeout in main loop: {e}")
# Continue loop but with a longer sleep
await asyncio.sleep(10)
except Exception as e:
logger.error(f"Error in main loop: {e}", exc_info=True)
# Continue running unless it's a critical error
if "circuit breaker" in str(e).lower():
logger.error("Circuit breaker opened, waiting before retry")
await asyncio.sleep(30)
else:
await asyncio.sleep(5)
if self.client.game.is_game_done:
logger.info("Game has finished")
@ -622,21 +766,94 @@ class SingleBotPlayer:
logger.info("Bot shutting down")
except GameIdException:
logger.error(f"Game with id {self.game_id} does not exist on the server. Exiting...")
except (asyncio.CancelledError, KeyboardInterrupt):
logger.info("Bot cancelled or interrupted")
except Exception as e:
logger.error(f"Fatal error in bot: {e}", exc_info=True)
finally:
await self.cleanup()
async def cleanup(self):
"""Clean up resources."""
"""Clean up resources with timeout protection."""
logger.info("Starting cleanup process...")
cleanup_timeout = 15.0 # Maximum time to spend on cleanup
try:
if self.client.game:
# We need to leave a game here, otherwise the connection stays for some reason and we can never rejoin
self.client.game.leave()
if self.client:
await self.client.close()
logger.info("Cleanup complete")
# Use asyncio.wait_for to prevent hanging during cleanup
await asyncio.wait_for(self._perform_cleanup(), timeout=cleanup_timeout)
logger.info("Cleanup completed successfully")
except asyncio.TimeoutError:
logger.warning(f"Cleanup timed out after {cleanup_timeout} seconds")
except Exception as e:
logger.error(f"Error during cleanup: {e}")
async def _perform_cleanup(self):
"""Perform the actual cleanup operations."""
cleanup_tasks = []
# Game cleanup
if hasattr(self, "client") and self.client and hasattr(self.client, "game") and self.client.game:
logger.debug("Cleaning up game connection...")
try:
# Use asyncio.create_task to make game.leave() non-blocking
leave_task = asyncio.create_task(self._safe_game_leave())
cleanup_tasks.append(leave_task)
except Exception as e:
logger.warning(f"Error creating game leave task: {e}")
# Client cleanup
if hasattr(self, "client") and self.client:
logger.debug("Cleaning up client connection...")
try:
close_task = asyncio.create_task(self._safe_client_close())
cleanup_tasks.append(close_task)
except Exception as e:
logger.warning(f"Error creating client close task: {e}")
# Wait for all cleanup tasks with individual timeouts
if cleanup_tasks:
done, pending = await asyncio.wait(
cleanup_tasks,
timeout=10.0, # 10 second timeout for all cleanup tasks
return_when=asyncio.ALL_COMPLETED,
)
# Cancel any pending tasks
for task in pending:
logger.warning(f"Cancelling pending cleanup task: {task}")
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
except Exception as e:
logger.warning(f"Error cancelling cleanup task: {e}")
async def _safe_game_leave(self):
"""Safely leave the game with timeout."""
try:
# Some diplomacy client implementations have async leave, others are sync
if asyncio.iscoroutinefunction(self.client.game.leave):
await asyncio.wait_for(self.client.game.leave(), timeout=5.0)
else:
# Run synchronous leave in a thread to avoid blocking
await asyncio.get_event_loop().run_in_executor(None, self.client.game.leave)
logger.debug("Successfully left game")
except asyncio.TimeoutError:
logger.warning("Game leave operation timed out")
except Exception as e:
logger.warning(f"Error leaving game: {e}")
async def _safe_client_close(self):
"""Safely close the client with timeout."""
try:
await asyncio.wait_for(self.client.close(), timeout=5.0)
logger.debug("Successfully closed client")
except asyncio.TimeoutError:
logger.warning("Client close operation timed out")
except Exception as e:
logger.warning(f"Error closing client: {e}")
def parse_arguments():
"""Parse command line arguments."""
@ -656,28 +873,63 @@ def parse_arguments():
default=3,
help="Number of negotiation rounds per movement phase (default: 3)",
)
parser.add_argument(
"--connection-timeout",
type=float,
default=30.0,
help="Timeout for network operations in seconds (default: 30.0)",
)
parser.add_argument(
"--max-retries",
type=int,
default=3,
help="Maximum number of retries for failed operations (default: 3)",
)
parser.add_argument(
"--retry-delay",
type=float,
default=2.0,
help="Base delay between retries in seconds (default: 2.0)",
)
return parser.parse_args()
async def main():
"""Main entry point."""
args = parse_arguments()
if not args.username:
args.username = f"bot_{args.power}"
"""Main entry point with comprehensive error handling."""
bot = None
try:
args = parse_arguments()
if not args.username:
args.username = f"bot_{args.power}"
bot = SingleBotPlayer(
hostname=args.hostname,
port=args.port,
username=args.username,
password=args.password,
power_name=args.power,
model_name=args.model,
game_id=args.game_id,
negotiation_rounds=args.negotiation_rounds,
)
config = Configuration(args.power)
await bot.run()
bot = SingleBotPlayer(
hostname=args.hostname,
port=args.port,
username=args.username,
password=args.password,
power_name=args.power,
model_name=args.model,
game_id=args.game_id,
negotiation_rounds=args.negotiation_rounds,
connection_timeout=args.connection_timeout,
max_retries=args.max_retries,
retry_delay=args.retry_delay,
)
await bot.run()
except KeyboardInterrupt:
logger.info("Received keyboard interrupt")
except Exception as e:
logger.error(f"Fatal error in main: {e}", exc_info=True)
finally:
if bot:
# Ensure cleanup happens even if there was an error
try:
await bot.cleanup()
except Exception as cleanup_error:
logger.error(f"Error during final cleanup: {cleanup_error}")
if __name__ == "__main__":

View file

@ -1 +1 @@
# Test package initialization
# Test package initialization

View file

@ -18,7 +18,7 @@ import pytest_asyncio
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from .test_fake_websocket_server import FakeServerManager, FakeWebSocketServer
from typed_websocket_client import TypedWebSocketDiplomacyClient
from websocket_diplomacy_client import WebSocketDiplomacyClient
@pytest_asyncio.fixture(scope="function", autouse=True)
@ -39,13 +39,13 @@ def credentials() -> Generator[Dict[str, str], None]:
@pytest_asyncio.fixture
async def client() -> AsyncGenerator[TypedWebSocketDiplomacyClient, None]:
async def client() -> AsyncGenerator[WebSocketDiplomacyClient, None]:
"""
Fixture that provides a TypedWebSocketDiplomacyClient instance.
Fixture that provides a WebSocketDiplomacyClient instance.
The client is configured to connect to the fake server on port 8433.
"""
client = TypedWebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
client = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
yield client
# Cleanup: close the client connection
@ -58,13 +58,13 @@ async def client() -> AsyncGenerator[TypedWebSocketDiplomacyClient, None]:
@pytest.fixture
async def authenticated_client(fake_server, client):
"""
Fixture that provides an authenticated TypedWebSocketDiplomacyClient.
Fixture that provides an authenticated WebSocketDiplomacyClient.
This client is already connected and authenticated, ready for testing
game operations.
"""
await client.connect()
token = await client.authenticate("test_user", "test_password")
await client.connect_and_authenticate("test_user", "test_password")
token = client.token
assert token is not None
assert client.token == token
yield client

View file

@ -38,49 +38,41 @@ from models import (
class FakeWebSocketServer:
"""
A fake WebSocket server that implements the Diplomacy protocol for testing.
This server maintains minimal state and responds to requests with valid
responses according to the WEBSOCKET.md protocol specification.
"""
def __init__(self, host: str = "localhost", port: int = 8433):
self.host = host
self.port = port
self.server = None
self.connected_clients: Set[websockets.WebSocketServerProtocol] = set()
# Mock server state
self.users = {
"test_user": "test_password",
"ai_player": "password",
"player1": "password"
}
self.users = {"test_user": "test_password", "ai_player": "password", "player1": "password"}
self.tokens: Dict[str, str] = {} # token -> username
self.games: Dict[str, Dict[str, Any]] = {}
self.game_counter = 1
async def start(self):
"""Start the fake WebSocket server."""
logger.info(f"Starting fake WebSocket server on {self.host}:{self.port}")
self.server = await websockets.serve(
self.handle_client,
self.host,
self.port
)
self.server = await websockets.serve(self.handle_client, self.host, self.port)
logger.info("Fake WebSocket server started")
async def stop(self):
"""Stop the fake WebSocket server."""
if self.server:
self.server.close()
await self.server.wait_closed()
logger.info("Fake WebSocket server stopped")
async def handle_client(self, websocket):
"""Handle a new WebSocket client connection."""
self.connected_clients.add(websocket)
logger.info(f"Client connected from {websocket.remote_address}")
try:
async for message in websocket:
await self.handle_message(websocket, message)
@ -90,49 +82,45 @@ class FakeWebSocketServer:
logger.error(f"Error handling client: {e}")
finally:
self.connected_clients.discard(websocket)
async def handle_message(self, websocket, message_text: str):
"""Handle an incoming message from a client."""
try:
# Parse the raw message
data = json.loads(message_text)
request = parse_message(data)
# Generate response based on request type
response = await self.generate_response(request)
# Send response back to client
if response:
response_data = serialize_message(response)
await websocket.send(json.dumps(response_data))
logger.debug(f"Sent response: {response.name}")
except Exception as e:
logger.error(f"Error processing message: {e}")
logger.debug(f"Raw message: {message_text}")
# Send error response if we can extract request_id
try:
data = json.loads(message_text)
request_id = data.get("request_id", str(uuid.uuid4()))
error_response = ErrorResponse(
request_id=request_id,
error_type="PARSING_ERROR",
message=str(e)
)
error_response = ErrorResponse(request_id=request_id, error_type="PARSING_ERROR", message=str(e))
error_data = serialize_message(error_response)
await websocket.send(json.dumps(error_data))
except:
logger.error("Could not send error response")
async def generate_response(self, request) -> Optional[Any]:
"""Generate appropriate response for a request."""
logger.info(f"Processing request: {request.name}")
# Authentication requests
if isinstance(request, SignInRequest):
return await self.handle_sign_in(request)
# Channel-level requests (require token)
elif isinstance(request, CreateGameRequest):
return await self.handle_create_game(request)
@ -140,7 +128,7 @@ class FakeWebSocketServer:
return await self.handle_join_game(request)
elif isinstance(request, ListGamesRequest):
return await self.handle_list_games(request)
# Game-level requests (require token + game context)
elif isinstance(request, SetOrdersRequest):
return await self.handle_set_orders(request)
@ -148,45 +136,36 @@ class FakeWebSocketServer:
return await self.handle_process_game(request)
elif isinstance(request, GetAllPossibleOrdersRequest):
return await self.handle_get_possible_orders(request)
else:
logger.warning(f"Unhandled request type: {type(request)}")
return ErrorResponse(
request_id=request.request_id,
error_type="UNSUPPORTED_REQUEST",
message=f"Request type {request.name} not supported by fake server"
request_id=request.request_id, error_type="UNSUPPORTED_REQUEST", message=f"Request type {request.name} not supported by fake server"
)
async def handle_sign_in(self, request: SignInRequest) -> Any:
"""Handle authentication request."""
username = request.username
password = request.password
if username in self.users and self.users[username] == password:
# Generate auth token
token = f"fake_token_{uuid.uuid4().hex[:16]}"
self.tokens[token] = username
return DataTokenResponse(
request_id=request.request_id,
data=token
)
return DataTokenResponse(request_id=request.request_id, data=token)
else:
return ErrorResponse(
request_id=request.request_id,
error_type="AUTHENTICATION_ERROR",
message="Invalid username or password"
)
return ErrorResponse(request_id=request.request_id, error_type="AUTHENTICATION_ERROR", message="Invalid username or password")
async def handle_create_game(self, request: CreateGameRequest) -> Any:
"""Handle game creation request."""
if not self.validate_token(request.token):
return self.create_auth_error(request.request_id)
# Create a new game
game_id = f"GAME_{self.game_counter:04d}"
self.game_counter += 1
game_data = {
"game_id": game_id,
"map_name": request.map_name,
@ -203,102 +182,83 @@ class FakeWebSocketServer:
"RUSSIA": {"units": [], "centers": ["MOS", "SEV", "STP", "WAR"], "is_eliminated": False},
"TURKEY": {"units": [], "centers": ["ANK", "CON", "SMY"], "is_eliminated": False},
},
"controlled_powers": {}
"controlled_powers": {},
}
# If a specific power was requested, assign it
if request.power_name:
game_data["controlled_powers"][request.power_name] = self.tokens[request.token]
self.games[game_id] = game_data
return DataGameResponse(
request_id=request.request_id,
data=game_data
)
return DataGameResponse(request_id=request.request_id, data=game_data)
async def handle_join_game(self, request: JoinGameRequest) -> Any:
"""Handle game join request."""
if not self.validate_token(request.token):
return self.create_auth_error(request.request_id)
game_id = request.game_id
if game_id not in self.games:
return ErrorResponse(
request_id=request.request_id,
error_type="GAME_NOT_FOUND",
message=f"Game {game_id} not found"
)
return ErrorResponse(request_id=request.request_id, error_type="GAME_NOT_FOUND", message=f"Game {game_id} not found")
game_data = self.games[game_id].copy()
# If a specific power was requested, assign it
if request.power_name:
game_data["controlled_powers"][request.power_name] = self.tokens[request.token]
self.games[game_id] = game_data
return DataGameResponse(
request_id=request.request_id,
data=game_data
)
return DataGameResponse(request_id=request.request_id, data=game_data)
async def handle_list_games(self, request: ListGamesRequest) -> Any:
"""Handle list games request."""
if not self.validate_token(request.token):
return self.create_auth_error(request.request_id)
# Return simplified game info
games_list = []
for game_id, game_data in self.games.items():
games_list.append({
"game_id": game_id,
"map_name": game_data["map_name"],
"status": game_data["status"],
"phase": game_data["phase"],
"n_controls": game_data["n_controls"]
})
return DataGamesResponse(
request_id=request.request_id,
data=games_list
)
games_list.append(
{
"game_id": game_id,
"map_name": game_data["map_name"],
"status": game_data["status"],
"phase": game_data["phase"],
"n_controls": game_data["n_controls"],
}
)
return DataGamesResponse(request_id=request.request_id, data=games_list)
async def handle_set_orders(self, request: SetOrdersRequest) -> Any:
"""Handle set orders request."""
if not self.validate_token(request.token):
return self.create_auth_error(request.request_id)
game_id = request.game_id
if game_id not in self.games:
return ErrorResponse(
request_id=request.request_id,
error_type="GAME_NOT_FOUND",
message=f"Game {game_id} not found"
)
return ErrorResponse(request_id=request.request_id, error_type="GAME_NOT_FOUND", message=f"Game {game_id} not found")
# In a real server, we'd validate the orders and store them
# For testing, we just acknowledge receipt
logger.info(f"Orders received for {request.game_role}: {request.orders}")
return OkResponse(request_id=request.request_id)
async def handle_process_game(self, request: ProcessGameRequest) -> Any:
"""Handle process game request."""
if not self.validate_token(request.token):
return self.create_auth_error(request.request_id)
game_id = request.game_id
if game_id not in self.games:
return ErrorResponse(
request_id=request.request_id,
error_type="GAME_NOT_FOUND",
message=f"Game {game_id} not found"
)
return ErrorResponse(request_id=request.request_id, error_type="GAME_NOT_FOUND", message=f"Game {game_id} not found")
# Simulate game processing
game_data = self.games[game_id]
current_phase = game_data["phase"]
# Simple phase progression
if current_phase == "S1901M":
game_data["phase"] = "F1901M"
@ -307,55 +267,48 @@ class FakeWebSocketServer:
else:
# For testing, cycle back to start
game_data["phase"] = "S1902M"
self.games[game_id] = game_data
# Send notification to all clients (in real implementation)
# For testing, we'll just return OK
return OkResponse(request_id=request.request_id)
async def handle_get_possible_orders(self, request: GetAllPossibleOrdersRequest) -> Any:
"""Handle get possible orders request."""
if not self.validate_token(request.token):
return self.create_auth_error(request.request_id)
# Return mock possible orders
possible_orders = {
"PAR": ["A PAR H", "A PAR - BUR", "A PAR - PIC", "A PAR - GAS"],
"BRE": ["F BRE H", "F BRE - MAO", "F BRE - ENG", "F BRE - PIC"],
"MAR": ["A MAR H", "A MAR - GAS", "A MAR - SPA", "A MAR - PIE"],
}
return DataPossibleOrdersResponse(
request_id=request.request_id,
data=possible_orders
)
return DataPossibleOrdersResponse(request_id=request.request_id, data=possible_orders)
def validate_token(self, token: str) -> bool:
"""Validate an authentication token."""
return token in self.tokens
def create_auth_error(self, request_id: str) -> ErrorResponse:
"""Create a standard authentication error response."""
return ErrorResponse(
request_id=request_id,
error_type="AUTHENTICATION_ERROR",
message="Invalid or missing authentication token"
)
return ErrorResponse(request_id=request_id, error_type="AUTHENTICATION_ERROR", message="Invalid or missing authentication token")
class FakeServerManager:
"""Context manager for the fake WebSocket server."""
def __init__(self, host: str = "localhost", port: int = 8433):
self.server = FakeWebSocketServer(host, port)
async def __aenter__(self):
await self.server.start()
# Give the server a moment to start
await asyncio.sleep(0.1)
return self.server
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.server.stop()
@ -375,4 +328,4 @@ async def run_fake_server(host: str = "localhost", port: int = 8433):
if __name__ == "__main__":
# Run the fake server for manual testing
asyncio.run(run_fake_server())
asyncio.run(run_fake_server())

View file

@ -0,0 +1,347 @@
"""
Integration tests for inter-power messaging functionality.
These tests verify that the enhanced WebSocket client can properly:
1. Send messages between powers
2. Retrieve and filter messages
3. Handle broadcast messaging
4. Track message patterns for strategic communication
The tests use the fake server infrastructure but test real message flow
without mocking the transport layer.
"""
import asyncio
import pytest
from typing import Dict, List
from unittest.mock import AsyncMock, MagicMock
from websocket_diplomacy_client import WebSocketDiplomacyClient, connect_to_diplomacy_server
from diplomacy.engine.message import Message, GLOBAL
class TestBasicMessaging:
"""Test core messaging functionality between powers."""
@pytest.fixture
async def two_clients(self, fake_server):
"""Fixture providing two authenticated clients in the same game."""
# Create two clients
client_france = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
client_germany = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
try:
# Connect and authenticate both clients
await client_france.connect_and_authenticate("test_user", "test_password")
await client_germany.connect_and_authenticate("ai_player", "password")
# Create a game with first client as FRANCE
await client_france.create_game(
map_name="standard",
rules=["IGNORE_ERRORS", "POWER_CHOICE"],
power_name="FRANCE",
n_controls=2, # Only need 2 powers for testing
)
# Second client joins as GERMANY
await client_germany.join_game(game_id=client_france.game_id, power_name="GERMANY")
# Synchronize both clients
await client_france.synchronize()
await client_germany.synchronize()
yield {"france": client_france, "germany": client_germany}
finally:
# Cleanup
try:
await client_france.close()
await client_germany.close()
except:
pass
async def test_direct_message_sending(self, two_clients):
"""Test sending a direct message between two powers."""
france_client = two_clients["france"]
germany_client = two_clients["germany"]
# Send message from FRANCE to GERMANY
test_message = "Hello Germany, shall we form an alliance?"
await france_client.send_message(sender="FRANCE", recipient="GERMANY", message=test_message)
# Allow message to propagate
await asyncio.sleep(0.1)
# Synchronize both clients to get latest messages
await france_client.synchronize()
await germany_client.synchronize()
# Check that Germany received the message
germany_messages = await germany_client.get_recent_messages(limit=10)
# Find our message
sent_message = None
for msg in germany_messages:
if msg.sender == "FRANCE" and msg.recipient == "GERMANY":
sent_message = msg
break
assert sent_message is not None, "Message was not received by Germany"
assert sent_message.message == test_message
assert sent_message.sender == "FRANCE"
assert sent_message.recipient == "GERMANY"
async def test_global_message_broadcasting(self, two_clients):
"""Test broadcasting a global message."""
france_client = two_clients["france"]
germany_client = two_clients["germany"]
# Send global message from FRANCE
test_message = "Greetings to all powers! Let's have a good game."
await france_client.send_message(sender="FRANCE", recipient=GLOBAL, message=test_message)
# Allow message to propagate
await asyncio.sleep(0.1)
# Synchronize both clients
await france_client.synchronize()
await germany_client.synchronize()
# Check that both clients can see the global message
france_messages = await france_client.get_recent_messages(limit=10)
germany_messages = await germany_client.get_recent_messages(limit=10)
# Find the global message in both clients
def find_global_message(messages):
for msg in messages:
if msg.sender == "FRANCE" and msg.recipient == GLOBAL:
return msg
return None
france_msg = find_global_message(france_messages)
germany_msg = find_global_message(germany_messages)
assert france_msg is not None, "France should see its own global message"
assert germany_msg is not None, "Germany should see France's global message"
assert france_msg.message == test_message
assert germany_msg.message == test_message
async def test_broadcast_to_all_active_powers(self, two_clients):
"""Test the enhanced broadcast functionality."""
france_client = two_clients["france"]
germany_client = two_clients["germany"]
# Send broadcast message to all active powers
test_message = "This is a broadcast to all active powers."
await france_client.send_broadcast_message(sender="FRANCE", message=test_message)
# Allow messages to propagate
await asyncio.sleep(0.1)
# Synchronize clients
await france_client.synchronize()
await germany_client.synchronize()
# Check that Germany received the direct message from broadcast
germany_messages = await germany_client.get_recent_messages(limit=10)
broadcast_message = None
for msg in germany_messages:
if msg.sender == "FRANCE" and msg.recipient == "GERMANY" and msg.message == test_message:
broadcast_message = msg
break
assert broadcast_message is not None, "Germany should receive broadcast message"
assert broadcast_message.message == test_message
async def test_message_filtering_by_phase(self, two_clients):
"""Test message retrieval filtered by game phase."""
france_client = two_clients["france"]
# Get current phase
current_phase = france_client.get_current_short_phase()
# Send a message
test_message = "Phase-specific message"
await france_client.send_message(sender="FRANCE", recipient="GERMANY", message=test_message)
# Allow message to propagate
await asyncio.sleep(0.1)
await france_client.synchronize()
# Get messages for current phase
phase_messages = await france_client.get_recent_messages(phase=current_phase, limit=10)
# Find our message
found_message = None
for msg in phase_messages:
if msg.message == test_message:
found_message = msg
break
assert found_message is not None, "Message should be found in current phase"
assert found_message.phase == current_phase
class TestMessageHistory:
"""Test message history and retrieval functionality."""
@pytest.fixture
async def client_with_messages(self, fake_server):
"""Fixture providing a client with some test messages."""
client = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
try:
await client.connect_and_authenticate("test_user", "test_password")
await client.create_game(
map_name="standard",
rules=["IGNORE_ERRORS", "POWER_CHOICE"],
power_name="FRANCE",
n_controls=1,
)
await client.synchronize()
# Send several test messages
test_messages = [
("Hello world!", GLOBAL),
("Private message to England", "ENGLAND"),
("Another global message", GLOBAL),
("Direct to Germany", "GERMANY"),
]
for message, recipient in test_messages:
await client.send_message(sender="FRANCE", recipient=recipient, message=message)
await asyncio.sleep(0.05) # Small delay between messages
await client.synchronize()
yield client
finally:
try:
await client.close()
except:
pass
async def test_get_recent_messages_limit(self, client_with_messages):
"""Test that message limit parameter works correctly."""
client = client_with_messages
# Get recent messages with different limits
messages_3 = await client.get_recent_messages(limit=3)
messages_2 = await client.get_recent_messages(limit=2)
messages_1 = await client.get_recent_messages(limit=1)
assert len(messages_3) <= 3
assert len(messages_2) <= 2
assert len(messages_1) <= 1
# Messages should be in reverse chronological order (most recent first)
if len(messages_3) > 1:
# Check that timestamps are in descending order
for i in range(len(messages_3) - 1):
msg1_time = messages_3[i].time_sent or 0
msg2_time = messages_3[i + 1].time_sent or 0
assert msg1_time >= msg2_time, "Messages should be in reverse chronological order"
async def test_message_retrieval_by_sender(self, client_with_messages):
"""Test filtering messages by sender."""
client = client_with_messages
# Get all recent messages
all_messages = await client.get_recent_messages(limit=20)
# Filter messages from FRANCE
france_messages = [msg for msg in all_messages if msg.sender == "FRANCE"]
# We should have the 4 messages we sent as FRANCE
assert len(france_messages) >= 4, f"Expected at least 4 FRANCE messages, got {len(france_messages)}"
# All messages should be from FRANCE
for msg in france_messages:
assert msg.sender == "FRANCE"
async def test_empty_message_history(self, fake_server):
"""Test behavior when no messages exist."""
client = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
try:
await client.connect_and_authenticate("ai_player", "password")
await client.create_game(
map_name="standard",
rules=["IGNORE_ERRORS"],
power_name="AUSTRIA",
n_controls=1,
)
await client.synchronize()
# Get messages when none exist
messages = await client.get_recent_messages(limit=10)
# Should return empty list, not None or error
assert isinstance(messages, list)
assert len(messages) == 0
finally:
try:
await client.close()
except:
pass
class TestMessageErrorHandling:
"""Test error handling in messaging functionality."""
@pytest.fixture
async def client(self, fake_server):
"""Basic authenticated client for error testing."""
client = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
try:
await client.connect_and_authenticate("player1", "password")
await client.create_game(
map_name="standard",
rules=["IGNORE_ERRORS"],
power_name="ITALY",
n_controls=1,
)
await client.synchronize()
yield client
finally:
try:
await client.close()
except:
pass
async def test_send_message_to_invalid_recipient(self, client):
"""Test sending message to non-existent power."""
# This should not raise an exception - the server/game should handle invalid recipients
try:
await client.send_message(sender="ITALY", recipient="INVALID_POWER", message="This should not crash")
# If we get here, the call succeeded (which is fine)
except Exception as e:
# If an exception is raised, it should be a specific diplomacy exception, not a crash
assert "INVALID_POWER" in str(e) or "recipient" in str(e).lower()
async def test_send_empty_message(self, client):
"""Test sending empty message."""
# Empty messages should be handled gracefully
await client.send_message(sender="ITALY", recipient="FRANCE", message="")
# If we get here without exception, the empty message was handled properly
async def test_get_messages_before_game_setup(self, fake_server):
"""Test getting messages when no game is joined."""
client = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
try:
await client.connect_and_authenticate("test_user", "test_password")
# Try to get messages without joining a game
with pytest.raises(Exception): # Should raise some form of exception
await client.get_recent_messages()
finally:
try:
await client.close()
except:
pass

View file

@ -0,0 +1,356 @@
"""
Unit tests for the new inter-power messaging functionality.
These tests focus on testing the specific functionality we added:
1. Enhanced WebSocketDiplomacyClient methods
2. Message tracking and statistics in SingleBotPlayer
3. Negotiation targeting logic
These are more focused unit tests that don't require full server integration.
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from typing import Dict, List
from websocket_diplomacy_client import WebSocketDiplomacyClient
from single_bot_player import SingleBotPlayer
from websocket_negotiations import (
analyze_recent_messages_for_targeting,
should_participate_in_negotiations,
get_negotiation_delay,
)
from diplomacy.engine.message import Message, GLOBAL
class TestWebSocketClientEnhancements:
"""Test the enhanced methods we added to WebSocketDiplomacyClient."""
@pytest.fixture
def mock_client(self):
"""Create a mocked WebSocketDiplomacyClient for testing."""
client = WebSocketDiplomacyClient()
# Mock the game and related objects
client.game = MagicMock()
mock_powers = {
"FRANCE": MagicMock(is_eliminated=lambda: False),
"ENGLAND": MagicMock(is_eliminated=lambda: False),
"GERMANY": MagicMock(is_eliminated=lambda: False),
"ITALY": MagicMock(is_eliminated=lambda: True), # Eliminated
}
client.game.powers = mock_powers
# Mock message history
mock_messages = [
Message(sender="ENGLAND", recipient="FRANCE", message="Hello France!", phase="S1901M"),
Message(sender="GERMANY", recipient="FRANCE", message="Greetings!", phase="S1901M"),
Message(sender="FRANCE", recipient="GLOBAL", message="Hello everyone!", phase="S1901M"),
]
client.game.messages = {i: msg for i, msg in enumerate(mock_messages)}
return client
async def test_send_broadcast_message(self, mock_client):
"""Test the new send_broadcast_message method."""
# Mock the send_message method and powers property
mock_client.send_message = AsyncMock()
# Mock the powers property to return our mock powers
with patch.object(type(mock_client), "powers", new_callable=lambda: property(lambda self: self.game.powers)):
# Send broadcast message
await mock_client.send_broadcast_message(sender="FRANCE", message="Hello to all active powers!")
# Should have called send_message for each active power (excluding sender and eliminated)
expected_calls = 2 # ENGLAND and GERMANY (not ITALY because eliminated, not FRANCE because sender)
assert mock_client.send_message.call_count == expected_calls
# Check that messages were sent to the right powers
call_args_list = mock_client.send_message.call_args_list
recipients = [call[1]["recipient"] for call in call_args_list] # Get recipient from kwargs
assert "ENGLAND" in recipients
assert "GERMANY" in recipients
assert "FRANCE" not in recipients # Shouldn't send to self
assert "ITALY" not in recipients # Shouldn't send to eliminated power
async def test_get_recent_messages_filtering(self, mock_client):
"""Test the new get_recent_messages method with filtering."""
# Mock current phase
mock_client.get_current_short_phase = MagicMock(return_value="S1901M")
# Mock the messages property to return our mock messages
with patch.object(type(mock_client), "messages", new_callable=lambda: property(lambda self: self.messages)):
# Test getting recent messages
recent_messages = await mock_client.get_recent_messages(limit=5)
# Should return the mocked messages
assert len(recent_messages) <= 5
assert len(recent_messages) == 3 # We have 3 mock messages
# Test phase filtering
phase_messages = await mock_client.get_recent_messages(phase="S1901M", limit=10)
assert all(msg.phase == "S1901M" for msg in phase_messages)
async def test_get_recent_messages_limit(self, mock_client):
"""Test that the limit parameter works correctly."""
# Mock the messages property to return our mock messages
with patch.object(type(mock_client), "messages", new_callable=lambda: property(lambda self: self.messages)):
# Test with limit smaller than available messages
limited_messages = await mock_client.get_recent_messages(limit=2)
assert len(limited_messages) <= 2
# Test with limit larger than available messages
all_messages = await mock_client.get_recent_messages(limit=100)
assert len(all_messages) == 3 # Should not exceed available messages
class TestSingleBotPlayerMessageTracking:
"""Test the message tracking functionality in SingleBotPlayer."""
@pytest.fixture
def mock_bot_player(self):
"""Create a mock SingleBotPlayer for testing."""
bot = SingleBotPlayer(username="test_bot", password="test_pass", power_name="FRANCE", model_name="test_model")
# Mock dependencies
bot.client = MagicMock()
bot.agent = MagicMock()
bot.game_history = MagicMock()
# Set up powers
bot.client.powers = {
"FRANCE": MagicMock(is_eliminated=lambda: False),
"ENGLAND": MagicMock(is_eliminated=lambda: False),
"GERMANY": MagicMock(is_eliminated=lambda: False),
"ITALY": MagicMock(is_eliminated=lambda: False),
}
return bot
def test_message_counting_initialization(self, mock_bot_player):
"""Test that message tracking starts in clean state."""
bot = mock_bot_player
assert isinstance(bot.message_counts, dict)
assert isinstance(bot.response_counts, dict)
assert isinstance(bot.priority_contacts, list)
assert len(bot.message_counts) == 0
assert len(bot.response_counts) == 0
assert len(bot.priority_contacts) == 0
def test_priority_contact_updates(self, mock_bot_player):
"""Test that priority contacts are updated correctly."""
bot = mock_bot_player
# Set up message counts
bot.message_counts = {
"ENGLAND": 5,
"GERMANY": 3,
"ITALY": 2,
"AUSTRIA": 1,
"RUSSIA": 4,
}
# Update priority contacts
bot._update_priority_contacts()
# Should have top 4 contacts in order of activity
assert len(bot.priority_contacts) == 4
assert bot.priority_contacts[0] == "ENGLAND" # Highest count (5)
assert bot.priority_contacts[1] == "RUSSIA" # Second highest (4)
assert bot.priority_contacts[2] == "GERMANY" # Third highest (3)
assert bot.priority_contacts[3] == "ITALY" # Fourth highest (2)
def test_message_statistics_generation(self, mock_bot_player):
"""Test generation of message statistics."""
bot = mock_bot_player
bot.current_phase = "S1901M"
# Set up data
bot.message_counts = {"ENGLAND": 4, "GERMANY": 2}
bot.response_counts = {"ENGLAND": 3, "GERMANY": 1}
bot.priority_contacts = ["ENGLAND", "GERMANY"]
# Generate statistics
stats = bot.get_message_statistics()
# Check basic structure
assert stats["power_name"] == "FRANCE"
assert stats["current_phase"] == "S1901M"
assert stats["total_messages_received"] == 6 # 4 + 2
assert stats["total_responses_sent"] == 4 # 3 + 1
# Check response rates
assert stats["response_rate_by_power"]["ENGLAND"] == 0.75 # 3/4
assert stats["response_rate_by_power"]["GERMANY"] == 0.5 # 1/2
def test_response_decision_logic(self, mock_bot_player):
"""Test the enhanced response decision logic."""
bot = mock_bot_player
bot.priority_contacts = ["ENGLAND"]
# Create test messages
priority_message = MagicMock()
priority_message.sender = "ENGLAND"
priority_message.recipient = "FRANCE"
priority_message.message = "Hello France!"
non_priority_message = MagicMock()
non_priority_message.sender = "GERMANY"
non_priority_message.recipient = "FRANCE"
non_priority_message.message = "Hello France!"
# Test the decision logic (extracted from _consider_message_response)
def should_respond(message):
message_lower = message.message.lower()
strategic_keywords = ["alliance", "deal", "propose", "agreement"]
return any(
[
"?" in message.message,
any(word in message_lower for word in ["hello", "hi", "greetings"]),
any(keyword in message_lower for keyword in strategic_keywords),
len(message.message.split()) > 15,
message.sender in bot.priority_contacts,
]
)
# Both should respond due to "hello", but priority logic is tested
assert should_respond(priority_message) is True
assert should_respond(non_priority_message) is True
# Test priority contact influence
assert priority_message.sender in bot.priority_contacts
assert non_priority_message.sender not in bot.priority_contacts
class TestNegotiationTargeting:
"""Test the strategic negotiation targeting logic."""
@pytest.fixture
def mock_client_with_messages(self):
"""Create a mock client with message history."""
client = MagicMock()
# Mock recent messages
mock_messages = [
Message(sender="ENGLAND", recipient="FRANCE", message="Direct to France", phase="S1901M"),
Message(sender="ENGLAND", recipient="FRANCE", message="Another to France", phase="S1901M"),
Message(sender="GERMANY", recipient="FRANCE", message="Message to France", phase="S1901M"),
Message(sender="ITALY", recipient="GLOBAL", message="Global message", phase="S1901M"),
]
client.get_recent_messages = AsyncMock(return_value=mock_messages)
client.powers = {
"FRANCE": MagicMock(is_eliminated=lambda: False),
"ENGLAND": MagicMock(is_eliminated=lambda: False),
"GERMANY": MagicMock(is_eliminated=lambda: False),
"ITALY": MagicMock(is_eliminated=lambda: False),
}
return client
async def test_analyze_recent_messages_for_targeting(self, mock_client_with_messages):
"""Test the message analysis for targeting."""
client = mock_client_with_messages
# Analyze targeting for FRANCE
targets = await analyze_recent_messages_for_targeting(client, "FRANCE", max_messages=20)
# Should return a list of powers
assert isinstance(targets, list)
assert "FRANCE" not in targets # Should not include self
# ENGLAND should be prioritized (sent 2 direct messages to FRANCE)
# GERMANY should be second (sent 1 direct message to FRANCE)
if len(targets) >= 2:
assert targets[0] == "ENGLAND" # Most direct messages
if "GERMANY" in targets:
germany_index = targets.index("GERMANY")
england_index = targets.index("ENGLAND")
assert england_index < germany_index # England should come before Germany
def test_negotiation_delay_calculation(self):
"""Test that negotiation delays are calculated correctly."""
# Test different round scenarios
first_delay = get_negotiation_delay(round_number=1, total_rounds=3)
middle_delay = get_negotiation_delay(round_number=2, total_rounds=3)
final_delay = get_negotiation_delay(round_number=3, total_rounds=3)
# First round should have longer delay
assert first_delay > middle_delay
# Final round should have shorter delay
assert final_delay < middle_delay
# All delays should be positive
assert all(delay > 0 for delay in [first_delay, middle_delay, final_delay])
async def test_should_participate_in_negotiations(self):
"""Test negotiation participation logic."""
# Mock client and agent
mock_client = MagicMock()
mock_agent = MagicMock()
mock_agent.power_name = "FRANCE"
# Test case: active power in movement phase
mock_power = MagicMock()
mock_power.is_eliminated.return_value = False
mock_client.get_power.return_value = mock_power
mock_client.get_current_short_phase.return_value = "S1901M"
with patch("websocket_negotiations.gather_possible_orders") as mock_orders:
mock_orders.return_value = ["A Paris - Hold"] # Has orders
result = await should_participate_in_negotiations(mock_client, mock_agent)
assert result is True
# Test case: eliminated power
mock_power.is_eliminated.return_value = True
result = await should_participate_in_negotiations(mock_client, mock_agent)
assert result is False
# Test case: non-movement phase
mock_power.is_eliminated.return_value = False
mock_client.get_current_short_phase.return_value = "S1901R" # Retreat phase
with patch("websocket_negotiations.gather_possible_orders") as mock_orders:
mock_orders.return_value = ["A Paris - Hold"]
result = await should_participate_in_negotiations(mock_client, mock_agent)
assert result is False
class TestIntegrationScenarios:
"""Test integration scenarios with mocked components."""
def test_message_persistence_across_phases(self):
"""Test that message counts persist across multiple game phases."""
bot = SingleBotPlayer(username="test_bot", password="test_pass", power_name="AUSTRIA", model_name="test_model")
# Mock dependencies
bot.client = MagicMock()
bot.agent = MagicMock()
bot.game_history = MagicMock()
# Simulate message accumulation over phases
phases = ["S1901M", "F1901M", "W1901A", "S1902M"]
for phase in phases:
bot.current_phase = phase
# Simulate receiving messages
if phase.endswith("M"): # Movement phases
for sender in ["FRANCE", "ENGLAND"]:
bot.message_counts[sender] = bot.message_counts.get(sender, 0) + 2
else:
# Fewer messages in other phases
bot.message_counts["FRANCE"] = bot.message_counts.get("FRANCE", 0) + 1
bot._update_priority_contacts()
# Check accumulated counts
assert bot.message_counts["FRANCE"] >= 5 # Should have accumulated messages
assert bot.message_counts["ENGLAND"] >= 4
assert len(bot.priority_contacts) > 0
assert bot.priority_contacts[0] == "FRANCE" # Should be top priority

View file

@ -0,0 +1,403 @@
"""
Integration tests for message persistence and tracking functionality.
These tests verify that the SingleBotPlayer can properly:
1. Track message counts and response patterns
2. Update priority contacts based on messaging activity
3. Generate accurate message statistics
4. Log statistics during phase transitions
5. Maintain message persistence across game phases
Tests focus on the tracking and statistics features rather than
the actual AI message generation.
"""
import asyncio
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from typing import Dict
from single_bot_player import SingleBotPlayer
from websocket_diplomacy_client import WebSocketDiplomacyClient
from diplomacy.engine.message import Message
from ai_diplomacy.agent import DiplomacyAgent
from ai_diplomacy.game_history import GameHistory
class TestMessageCounting:
"""Test message counting and tracking functionality."""
@pytest.fixture
def mock_bot_player(self):
"""Create a SingleBotPlayer with mocked dependencies for testing."""
# Create mock bot player
bot = SingleBotPlayer(username="test_bot", password="test_pass", power_name="ENGLAND", model_name="test_model", game_id="test_game")
# Mock the client and agent
bot.client = MagicMock(spec=WebSocketDiplomacyClient)
bot.agent = MagicMock(spec=DiplomacyAgent)
bot.game_history = GameHistory()
return bot
def test_initial_message_tracking_state(self, mock_bot_player):
"""Test that message tracking starts in clean state."""
bot = mock_bot_player
assert isinstance(bot.message_counts, dict)
assert isinstance(bot.response_counts, dict)
assert isinstance(bot.priority_contacts, list)
assert len(bot.message_counts) == 0
assert len(bot.response_counts) == 0
assert len(bot.priority_contacts) == 0
def test_message_count_tracking(self, mock_bot_player):
"""Test that incoming messages are counted correctly."""
bot = mock_bot_player
# Simulate receiving messages from different powers
test_messages = [
Message(sender="FRANCE", recipient="ENGLAND", message="Hello England!", phase="S1901M"),
Message(sender="FRANCE", recipient="ENGLAND", message="Another message", phase="S1901M"),
Message(sender="GERMANY", recipient="ENGLAND", message="Greetings!", phase="S1901M"),
Message(sender="ITALY", recipient="GLOBAL", message="Global message", phase="S1901M"), # Should not be counted
Message(sender="ENGLAND", recipient="FRANCE", message="Self message", phase="S1901M"), # Should not be counted
]
# Process each message through the message handler
for msg in test_messages:
# Simulate the message handling logic from _on_message_received
bot.game_history.add_message(
phase_name=msg.phase,
sender=msg.sender,
recipient=msg.recipient,
message_content=msg.message,
)
# Track message patterns (only for messages TO this bot)
if msg.recipient == bot.power_name and msg.sender != bot.power_name:
bot.message_counts[msg.sender] = bot.message_counts.get(msg.sender, 0) + 1
bot._update_priority_contacts()
# Check counts
assert bot.message_counts["FRANCE"] == 2
assert bot.message_counts["GERMANY"] == 1
assert "ITALY" not in bot.message_counts # Global message not counted
assert "ENGLAND" not in bot.message_counts # Self message not counted
def test_response_count_tracking(self, mock_bot_player):
"""Test that outgoing responses are counted correctly."""
bot = mock_bot_player
# Simulate sending responses to different powers
responses = [
("FRANCE", "Thanks for your message!"),
("FRANCE", "Another response to France"),
("GERMANY", "Hello Germany"),
]
for recipient, message in responses:
# Simulate response sending logic
bot.response_counts[recipient] = bot.response_counts.get(recipient, 0) + 1
# Check response counts
assert bot.response_counts["FRANCE"] == 2
assert bot.response_counts["GERMANY"] == 1
def test_priority_contact_updates(self, mock_bot_player):
"""Test that priority contacts are updated based on message counts."""
bot = mock_bot_player
# Set up message counts with different activity levels
bot.message_counts = {
"FRANCE": 5, # Most active
"GERMANY": 3, # Second most active
"ITALY": 2, # Third most active
"AUSTRIA": 1, # Least active
"RUSSIA": 4, # Second highest
}
# Update priority contacts
bot._update_priority_contacts()
# Should have top 4 contacts in order of activity
assert len(bot.priority_contacts) == 4
assert bot.priority_contacts[0] == "FRANCE" # Highest count (5)
assert bot.priority_contacts[1] == "RUSSIA" # Second highest (4)
assert bot.priority_contacts[2] == "GERMANY" # Third highest (3)
assert bot.priority_contacts[3] == "ITALY" # Fourth highest (2)
# AUSTRIA should not be in top 4
def test_priority_contacts_with_fewer_powers(self, mock_bot_player):
"""Test priority contacts when fewer than 4 powers are active."""
bot = mock_bot_player
# Set up message counts with only 2 powers
bot.message_counts = {
"FRANCE": 3,
"GERMANY": 1,
}
bot._update_priority_contacts()
# Should have only 2 contacts
assert len(bot.priority_contacts) == 2
assert bot.priority_contacts[0] == "FRANCE"
assert bot.priority_contacts[1] == "GERMANY"
class TestMessageStatistics:
"""Test message statistics generation and reporting."""
@pytest.fixture
def bot_with_message_data(self, mock_bot_player):
"""Bot player with pre-populated message tracking data."""
bot = mock_bot_player
# Set up realistic message and response data
bot.message_counts = {
"FRANCE": 8,
"GERMANY": 5,
"ITALY": 3,
"RUSSIA": 2,
}
bot.response_counts = {
"FRANCE": 6, # 75% response rate
"GERMANY": 2, # 40% response rate
"ITALY": 3, # 100% response rate
"RUSSIA": 0, # 0% response rate
}
bot.priority_contacts = ["FRANCE", "GERMANY", "ITALY", "RUSSIA"]
bot.current_phase = "S1901M"
# Mock the client's powers for active power detection
mock_powers = {
"ENGLAND": MagicMock(is_eliminated=lambda: False),
"FRANCE": MagicMock(is_eliminated=lambda: False),
"GERMANY": MagicMock(is_eliminated=lambda: False),
"ITALY": MagicMock(is_eliminated=lambda: False),
"RUSSIA": MagicMock(is_eliminated=lambda: False),
"AUSTRIA": MagicMock(is_eliminated=lambda: True), # Eliminated
"TURKEY": MagicMock(is_eliminated=lambda: False),
}
bot.client.powers = mock_powers
return bot
def test_message_statistics_generation(self, bot_with_message_data):
"""Test that message statistics are generated correctly."""
bot = bot_with_message_data
stats = bot.get_message_statistics()
# Check basic structure
assert stats["power_name"] == "ENGLAND"
assert stats["current_phase"] == "S1901M"
assert isinstance(stats["message_counts_by_power"], dict)
assert isinstance(stats["response_counts_by_power"], dict)
assert isinstance(stats["response_rate_by_power"], dict)
assert isinstance(stats["priority_contacts"], list)
assert isinstance(stats["active_powers"], list)
# Check calculated values
assert stats["total_messages_received"] == 18 # 8+5+3+2
assert stats["total_responses_sent"] == 11 # 6+2+3+0
# Check response rates
assert stats["response_rate_by_power"]["FRANCE"] == 0.75 # 6/8
assert stats["response_rate_by_power"]["GERMANY"] == 0.4 # 2/5
assert stats["response_rate_by_power"]["ITALY"] == 1.0 # 3/3
assert stats["response_rate_by_power"]["RUSSIA"] == 0.0 # 0/2
# Check active powers (should exclude eliminated AUSTRIA and self)
active_powers = stats["active_powers"]
assert "AUSTRIA" not in active_powers # Eliminated
assert "ENGLAND" not in active_powers # Self
assert "FRANCE" in active_powers
assert "TURKEY" in active_powers
def test_message_statistics_empty_data(self, mock_bot_player):
"""Test statistics generation with no message data."""
bot = mock_bot_player
bot.current_phase = "S1901M"
# Mock empty powers
mock_powers = {
"ENGLAND": MagicMock(is_eliminated=lambda: False),
"FRANCE": MagicMock(is_eliminated=lambda: False),
}
bot.client.powers = mock_powers
stats = bot.get_message_statistics()
assert stats["total_messages_received"] == 0
assert stats["total_responses_sent"] == 0
assert len(stats["message_counts_by_power"]) == 0
assert len(stats["response_counts_by_power"]) == 0
assert len(stats["response_rate_by_power"]) == 0
assert len(stats["priority_contacts"]) == 0
def test_message_statistics_logging(self, bot_with_message_data, caplog):
"""Test that message statistics are logged correctly."""
bot = bot_with_message_data
# Call the logging method
bot.log_message_statistics()
# Check that appropriate log messages were generated
log_output = caplog.text
assert "Message Statistics for ENGLAND" in log_output
assert "Total messages received: 18" in log_output
assert "Total responses sent: 11" in log_output
assert "Priority contacts:" in log_output
# Check that individual power stats are logged
assert "FRANCE:" in log_output
assert "GERMANY:" in log_output
assert "75%" in log_output # France response rate
assert "40%" in log_output # Germany response rate
class TestMessagePersistenceIntegration:
"""Test integration of message persistence with game flow."""
@pytest.fixture
async def bot_integration_setup(self, fake_server):
"""Setup for integration testing with fake server."""
# Note: This is a more complex fixture that would require actual
# SingleBotPlayer initialization, which depends on AI client setup
# For now, we'll focus on the core tracking logic
bot = SingleBotPlayer(username="integration_test", password="test_pass", power_name="ITALY", model_name="test_model", game_id=None)
# Mock the complex dependencies
bot.client = MagicMock(spec=WebSocketDiplomacyClient)
bot.agent = MagicMock(spec=DiplomacyAgent)
bot.game_history = GameHistory()
return bot
def test_message_tracking_during_phase_transition(self, bot_integration_setup):
"""Test that message statistics are logged during phase transitions."""
bot = bot_integration_setup
# Set up some message data
bot.message_counts = {"FRANCE": 2, "GERMANY": 1}
bot.current_phase = "S1901M"
# Mock the client powers for log_message_statistics
mock_powers = {
"ITALY": MagicMock(is_eliminated=lambda: False),
"FRANCE": MagicMock(is_eliminated=lambda: False),
"GERMANY": MagicMock(is_eliminated=lambda: False),
}
bot.client.powers = mock_powers
# Mock the log_message_statistics method to track if it's called
with patch.object(bot, "log_message_statistics") as mock_log:
# Simulate phase transition logic
new_phase = "F1901M"
if new_phase != bot.current_phase:
bot.current_phase = new_phase
bot.game_history.add_phase(new_phase)
bot.orders_submitted = False
bot.current_negotiation_round = 0
bot.negotiation_complete = False
# This is the key logic from _handle_phase_update_async
if hasattr(bot, "message_counts") and bot.message_counts:
bot.log_message_statistics()
# Verify that statistics were logged
mock_log.assert_called_once()
def test_priority_contact_influence_on_response_decisions(self, bot_integration_setup):
"""Test that priority contacts influence message response decisions."""
bot = bot_integration_setup
# Set up priority contacts
bot.priority_contacts = ["FRANCE", "GERMANY"]
# Create test messages from different senders
priority_message = Message(sender="FRANCE", recipient="ITALY", message="Hello Italy!", phase="S1901M")
non_priority_message = Message(sender="AUSTRIA", recipient="ITALY", message="Hello Italy!", phase="S1901M")
# Test the response decision logic (from _consider_message_response)
def should_respond_to_message(message):
message_lower = message.message.lower()
strategic_keywords = [
"alliance",
"deal",
"propose",
"agreement",
"support",
"attack",
"coordinate",
"move",
"order",
"help",
"work together",
"partner",
"enemy",
"threat",
"negotiate",
"discuss",
"plan",
"strategy",
"bounce",
"convoy",
"retreat",
]
return any(
[
"?" in message.message, # Questions
any(word in message_lower for word in ["hello", "hi", "greetings"]), # Greetings
any(keyword in message_lower for keyword in strategic_keywords), # Strategic content
len(message.message.split()) > 15, # Longer messages
message.sender in bot.priority_contacts, # Priority contacts
]
)
# Priority contact should be more likely to get response
priority_should_respond = should_respond_to_message(priority_message)
non_priority_should_respond = should_respond_to_message(non_priority_message)
# Both should respond due to "hello" keyword, but priority contact logic is working
assert priority_should_respond is True
# Non-priority should also respond due to "hello", but test the contact logic
assert priority_message.sender in bot.priority_contacts
assert non_priority_message.sender not in bot.priority_contacts
def test_message_persistence_across_multiple_phases(self, bot_integration_setup):
"""Test that message counts persist across multiple game phases."""
bot = bot_integration_setup
# Simulate message accumulation over multiple phases
phases = ["S1901M", "F1901M", "W1901A", "S1902M"]
for phase in phases:
bot.current_phase = phase
# Simulate receiving messages in each phase
if phase.endswith("M"): # Movement phases
# More messages during movement phases
for sender in ["FRANCE", "GERMANY"]:
bot.message_counts[sender] = bot.message_counts.get(sender, 0) + 2
else:
# Fewer messages during other phases
bot.message_counts["FRANCE"] = bot.message_counts.get("FRANCE", 0) + 1
bot._update_priority_contacts()
# After all phases, check accumulated counts
assert bot.message_counts["FRANCE"] >= 5 # 2+2+1+2 from movement phases + 1 from adjustment
assert bot.message_counts["GERMANY"] >= 4 # 2+2+2 from movement phases only
# France should be top priority due to higher count
assert len(bot.priority_contacts) > 0
assert bot.priority_contacts[0] == "FRANCE"

View file

@ -0,0 +1,440 @@
"""
Integration tests for multi-bot messaging scenarios.
These tests verify realistic inter-power communication scenarios with
multiple bots interacting simultaneously. Tests focus on:
1. Multi-bot conversation flows
2. Negotiation coordination between multiple powers
3. Message response patterns in realistic scenarios
4. Phase transition coordination across multiple bots
Tests use the fake server but simulate realistic multi-power interactions
without mocking the core message transport logic.
"""
import asyncio
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from typing import Dict, List
from websocket_diplomacy_client import WebSocketDiplomacyClient
from single_bot_player import SingleBotPlayer
from ai_diplomacy.agent import DiplomacyAgent
from ai_diplomacy.game_history import GameHistory
class TestTwoBotConversation:
"""Test conversation flow between two bots."""
@pytest.fixture
async def two_bot_setup(self, fake_server):
"""Setup two bots in the same game for conversation testing."""
# Create two WebSocket clients
client_france = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
client_england = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
try:
# Connect both clients
await client_france.connect_and_authenticate("france_bot", "password")
await client_england.connect_and_authenticate("england_bot", "password")
# Create game with France
await client_france.create_game(
map_name="standard",
rules=["IGNORE_ERRORS", "POWER_CHOICE"],
power_name="FRANCE",
n_controls=2,
)
# England joins the game
await client_england.join_game(game_id=client_france.game_id, power_name="ENGLAND")
# Synchronize both clients
await client_france.synchronize()
await client_england.synchronize()
yield {"france_client": client_france, "england_client": client_england, "game_id": client_france.game_id}
finally:
try:
await client_france.close()
await client_england.close()
except:
pass
async def test_basic_two_bot_exchange(self, two_bot_setup):
"""Test basic message exchange between two bots."""
france_client = two_bot_setup["france_client"]
england_client = two_bot_setup["england_client"]
# France initiates conversation
await france_client.send_message(sender="FRANCE", recipient="ENGLAND", message="Hello England! Shall we discuss our border?")
await asyncio.sleep(0.1)
await england_client.synchronize()
# Check England received the message
england_messages = await england_client.get_recent_messages(limit=5)
france_message = None
for msg in england_messages:
if msg.sender == "FRANCE" and msg.recipient == "ENGLAND":
france_message = msg
break
assert france_message is not None
assert "border" in france_message.message
# England responds
await england_client.send_message(sender="ENGLAND", recipient="FRANCE", message="Indeed, France. I propose we coordinate our fleets.")
await asyncio.sleep(0.1)
await france_client.synchronize()
# Check France received the response
france_messages = await france_client.get_recent_messages(limit=5)
england_response = None
for msg in france_messages:
if msg.sender == "ENGLAND" and msg.recipient == "FRANCE":
england_response = msg
break
assert england_response is not None
assert "coordinate" in england_response.message
async def test_conversation_thread_tracking(self, two_bot_setup):
"""Test that conversation threads can be tracked across multiple exchanges."""
france_client = two_bot_setup["france_client"]
england_client = two_bot_setup["england_client"]
# Simulate a conversation thread
conversation = [
("FRANCE", "ENGLAND", "England, are you interested in an alliance?"),
("ENGLAND", "FRANCE", "Yes France, what do you propose?"),
("FRANCE", "ENGLAND", "Let's coordinate attacks on Germany."),
("ENGLAND", "FRANCE", "Agreed. I'll move my fleet to support you."),
]
for sender, recipient, message in conversation:
if sender == "FRANCE":
await france_client.send_message(sender, recipient, message)
else:
await england_client.send_message(sender, recipient, message)
await asyncio.sleep(0.05) # Small delay between messages
# Synchronize both clients
await france_client.synchronize()
await england_client.synchronize()
# Get conversation history from both perspectives
france_messages = await france_client.get_recent_messages(limit=10)
england_messages = await england_client.get_recent_messages(limit=10)
# Filter for conversation between France and England
def filter_conversation(messages):
return [msg for msg in messages if (msg.sender in ["FRANCE", "ENGLAND"] and msg.recipient in ["FRANCE", "ENGLAND"])]
france_conv = filter_conversation(france_messages)
england_conv = filter_conversation(england_messages)
# Both should see the same conversation
assert len(france_conv) >= 4
assert len(england_conv) >= 4
# Check that key terms from the conversation appear
all_messages_text = " ".join([msg.message for msg in france_conv])
assert "alliance" in all_messages_text
assert "Germany" in all_messages_text
assert "support" in all_messages_text
class TestThreeBotNegotiation:
"""Test more complex negotiations with three bots."""
@pytest.fixture
async def three_bot_setup(self, fake_server):
"""Setup three bots for complex negotiation testing."""
clients = {}
powers = ["FRANCE", "ENGLAND", "GERMANY"]
try:
# Create and connect three clients
for power in powers:
client = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
await client.connect_and_authenticate(f"{power.lower()}_bot", "password")
clients[power] = client
# Create game with France
await clients["FRANCE"].create_game(
map_name="standard",
rules=["IGNORE_ERRORS", "POWER_CHOICE"],
power_name="FRANCE",
n_controls=3,
)
game_id = clients["FRANCE"].game_id
# Other powers join
await clients["ENGLAND"].join_game(game_id=game_id, power_name="ENGLAND")
await clients["GERMANY"].join_game(game_id=game_id, power_name="GERMANY")
# Synchronize all clients
for client in clients.values():
await client.synchronize()
yield {"clients": clients, "game_id": game_id}
finally:
for client in clients.values():
try:
await client.close()
except:
pass
async def test_three_way_alliance_negotiation(self, three_bot_setup):
"""Test alliance negotiation between three powers."""
clients = three_bot_setup["clients"]
# Simulate alliance negotiation sequence
negotiations = [
("FRANCE", "ENGLAND", "England, shall we form an alliance against Germany?"),
("ENGLAND", "FRANCE", "I'm interested. What are your terms?"),
("FRANCE", "GERMANY", "Germany, France and England are discussing cooperation."),
("GERMANY", "FRANCE", "I see. Perhaps we should talk as well."),
("GERMANY", "ENGLAND", "England, what is France offering you?"),
("ENGLAND", "GERMANY", "Germany, I think we should all work together."),
]
# Send all negotiation messages
for sender, recipient, message in negotiations:
await clients[sender].send_message(sender, recipient, message)
await asyncio.sleep(0.1) # Allow message processing
# Synchronize all clients
for client in clients.values():
await client.synchronize()
# Analyze message patterns from each perspective
for power, client in clients.items():
messages = await client.get_recent_messages(limit=20)
# Count messages involving this power
involving_power = [msg for msg in messages if msg.sender == power or msg.recipient == power]
# Each power should be involved in multiple messages
assert len(involving_power) >= 2, f"{power} should be involved in multiple messages"
# Check that all three powers have communicated
all_messages = await clients["FRANCE"].get_recent_messages(limit=20)
senders = set(msg.sender for msg in all_messages)
recipients = set(msg.recipient for msg in all_messages)
# All three powers should appear as senders
assert "FRANCE" in senders
assert "ENGLAND" in senders
assert "GERMANY" in senders
async def test_broadcast_with_private_follow_ups(self, three_bot_setup):
"""Test broadcast message followed by private conversations."""
clients = three_bot_setup["clients"]
# France sends a global announcement
await clients["FRANCE"].send_broadcast_message(sender="FRANCE", message="All powers: I propose we establish clear spheres of influence.")
await asyncio.sleep(0.2)
# Follow up with private messages to each power
private_messages = [
("FRANCE", "ENGLAND", "England, I suggest you focus on the seas."),
("FRANCE", "GERMANY", "Germany, the eastern approach might suit you."),
]
for sender, recipient, message in private_messages:
await clients[sender].send_message(sender, recipient, message)
await asyncio.sleep(0.1)
# Synchronize all clients
for client in clients.values():
await client.synchronize()
# Check that England and Germany received their specific messages
england_messages = await clients["ENGLAND"].get_recent_messages(limit=10)
germany_messages = await clients["GERMANY"].get_recent_messages(limit=10)
# England should see both broadcast (to all) and private message
england_private = [msg for msg in england_messages if msg.sender == "FRANCE" and msg.recipient == "ENGLAND"]
england_broadcast = [msg for msg in england_messages if msg.sender == "FRANCE" and "spheres of influence" in msg.message]
assert len(england_private) >= 1, "England should receive private message"
assert any("seas" in msg.message for msg in england_private), "England should get seas message"
# Germany should see broadcast and their private message
germany_private = [msg for msg in germany_messages if msg.sender == "FRANCE" and msg.recipient == "GERMANY"]
assert len(germany_private) >= 1, "Germany should receive private message"
assert any("eastern" in msg.message for msg in germany_private), "Germany should get eastern message"
class TestBotPlayerMessageIntegration:
"""Test SingleBotPlayer message handling in multi-bot scenarios."""
@pytest.fixture
def mock_bot_players(self):
"""Create multiple mock SingleBotPlayer instances."""
bots = {}
powers = ["FRANCE", "ENGLAND", "GERMANY"]
for power in powers:
bot = SingleBotPlayer(
username=f"{power.lower()}_bot", password="test_pass", power_name=power, model_name="test_model", game_id="test_game"
)
# Mock dependencies
bot.client = MagicMock(spec=WebSocketDiplomacyClient)
bot.agent = MagicMock(spec=DiplomacyAgent)
bot.game_history = GameHistory()
# Set up power name correctly
bot.agent.power_name = power
bots[power] = bot
return bots
def test_multi_bot_priority_contact_evolution(self, mock_bot_players):
"""Test how priority contacts evolve in multi-bot scenarios."""
bots = mock_bot_players
# Simulate asymmetric communication patterns
# France talks to everyone, England focuses on France, Germany is quiet
# France receives messages from multiple powers
france_bot = bots["FRANCE"]
france_bot.message_counts = {
"ENGLAND": 5, # England is very active with France
"GERMANY": 2, # Germany occasionally talks to France
"ITALY": 1, # Italy sends one message
}
france_bot._update_priority_contacts()
# England receives mostly from France
england_bot = bots["ENGLAND"]
england_bot.message_counts = {
"FRANCE": 8, # France talks to England a lot
"GERMANY": 1, # Germany sends one message
}
england_bot._update_priority_contacts()
# Germany receives few messages
germany_bot = bots["GERMANY"]
germany_bot.message_counts = {
"FRANCE": 3, # Some communication with France
"ENGLAND": 1, # Minimal with England
}
germany_bot._update_priority_contacts()
# Check that priority contacts reflect communication patterns
assert france_bot.priority_contacts[0] == "ENGLAND" # Most active with France
assert england_bot.priority_contacts[0] == "FRANCE" # France is England's main contact
assert germany_bot.priority_contacts[0] == "FRANCE" # France is Germany's main contact
# Check list lengths
assert len(france_bot.priority_contacts) == 3 # Three powers contacted France
assert len(england_bot.priority_contacts) == 2 # Two powers contacted England
assert len(germany_bot.priority_contacts) == 2 # Two powers contacted Germany
def test_response_pattern_analysis(self, mock_bot_players):
"""Test analysis of response patterns across multiple bots."""
bots = mock_bot_players
# Set up different response patterns for each bot
# France: Responsive to everyone
france_bot = bots["FRANCE"]
france_bot.message_counts = {"ENGLAND": 4, "GERMANY": 3}
france_bot.response_counts = {"ENGLAND": 4, "GERMANY": 3} # 100% response rate
# England: Selective responder
england_bot = bots["ENGLAND"]
england_bot.message_counts = {"FRANCE": 6, "GERMANY": 2}
england_bot.response_counts = {"FRANCE": 6, "GERMANY": 0} # Only responds to France
# Germany: Poor responder
germany_bot = bots["GERMANY"]
germany_bot.message_counts = {"FRANCE": 5, "ENGLAND": 3}
germany_bot.response_counts = {"FRANCE": 2, "ENGLAND": 1} # Low response rates
# Mock client powers for statistics
for power, bot in bots.items():
mock_powers = {p: MagicMock(is_eliminated=lambda: False) for p in ["FRANCE", "ENGLAND", "GERMANY", "ITALY"]}
bot.client.powers = mock_powers
# Generate statistics for each bot
stats = {}
for power, bot in bots.items():
stats[power] = bot.get_message_statistics()
# Analyze response patterns
france_stats = stats["FRANCE"]
england_stats = stats["ENGLAND"]
germany_stats = stats["GERMANY"]
# France should have high overall response rate
france_avg_response = sum(france_stats["response_rate_by_power"].values()) / len(france_stats["response_rate_by_power"])
assert france_avg_response == 1.0 # Perfect responder
# England should have selective response pattern
assert england_stats["response_rate_by_power"]["FRANCE"] == 1.0 # Always responds to France
assert england_stats["response_rate_by_power"]["GERMANY"] == 0.0 # Never responds to Germany
# Germany should have low overall response rate
germany_avg_response = sum(germany_stats["response_rate_by_power"].values()) / len(germany_stats["response_rate_by_power"])
assert germany_avg_response < 0.5 # Poor overall response rate
def test_message_history_consistency_across_bots(self, mock_bot_players):
"""Test that message history tracking is consistent across different bots."""
bots = mock_bot_players
# Simulate the same set of messages being processed by different bots
# (as they would see them in a real game)
from diplomacy.engine.message import Message
shared_messages = [
Message(sender="FRANCE", recipient="ENGLAND", message="Alliance proposal", phase="S1901M"),
Message(sender="ENGLAND", recipient="FRANCE", message="I accept", phase="S1901M"),
Message(sender="GERMANY", recipient="FRANCE", message="What about me?", phase="S1901M"),
Message(sender="FRANCE", recipient="GERMANY", message="You're welcome too", phase="S1901M"),
]
# Each bot processes messages relevant to them
for message in shared_messages:
for power, bot in bots.items():
# Add message to game history (all bots see all messages)
bot.game_history.add_message(
phase_name=message.phase,
sender=message.sender,
recipient=message.recipient,
message_content=message.message,
)
# Track messages directed at this bot
if message.recipient == power and message.sender != power:
bot.message_counts[message.sender] = bot.message_counts.get(message.sender, 0) + 1
bot._update_priority_contacts()
# Verify that each bot has tracked messages correctly
# France should have received 2 messages (from England and Germany)
assert bots["FRANCE"].message_counts["ENGLAND"] == 1
assert bots["FRANCE"].message_counts["GERMANY"] == 1
# England should have received 1 message (from France)
assert bots["ENGLAND"].message_counts["FRANCE"] == 1
# Germany should have received 1 message (from France)
assert bots["GERMANY"].message_counts["FRANCE"] == 1
# Check that game history is consistent (all bots see all 4 messages)
for power, bot in bots.items():
all_messages = bot.game_history.get_messages_for_phase("S1901M")
assert len(all_messages) == 4, f"{power} should see all 4 messages in game history"

View file

@ -0,0 +1,433 @@
"""
Integration tests for negotiation round functionality.
These tests verify that the negotiation system can properly:
1. Conduct strategic negotiation rounds with message targeting
2. Analyze recent messages for targeting decisions
3. Handle negotiation timing and coordination
4. Determine negotiation participation appropriately
Tests use mocked AI responses for predictable behavior while testing
the real negotiation coordination logic.
"""
import asyncio
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from typing import Dict, List
from websocket_diplomacy_client import WebSocketDiplomacyClient
from websocket_negotiations import (
conduct_strategic_negotiation_round,
analyze_recent_messages_for_targeting,
should_participate_in_negotiations,
get_negotiation_delay,
)
from ai_diplomacy.agent import DiplomacyAgent
from ai_diplomacy.game_history import GameHistory
class TestMessageTargeting:
"""Test strategic message targeting analysis."""
@pytest.fixture
async def client_with_messages(self, fake_server):
"""Client with pre-populated message history for targeting tests."""
client = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
try:
await client.connect_and_authenticate("targeting_user", "password")
await client.create_game(
map_name="standard",
rules=["IGNORE_ERRORS", "POWER_CHOICE"],
power_name="FRANCE",
n_controls=7, # Full game for realistic targeting
)
await client.synchronize()
# Simulate message history with different activity levels
# Note: In a real scenario, these would come from other players
# For testing, we'll add them to the game's message history directly
yield client
finally:
try:
await client.close()
except:
pass
async def test_analyze_recent_messages_empty_history(self, client_with_messages):
"""Test targeting analysis with no message history."""
client = client_with_messages
# Get targeting priority with empty message history
targets = await analyze_recent_messages_for_targeting(client=client, power_name="FRANCE", max_messages=20)
# Should return all active powers (excluding FRANCE itself)
assert isinstance(targets, list)
# Should not include FRANCE itself
assert "FRANCE" not in targets
# Should include other major powers
expected_powers = {"ENGLAND", "GERMANY", "ITALY", "AUSTRIA", "RUSSIA", "TURKEY"}
assert len(set(targets) & expected_powers) > 0
async def test_analyze_recent_messages_with_activity(self, client_with_messages):
"""Test targeting analysis with simulated message activity."""
client = client_with_messages
# Simulate some message activity by sending messages
# This tests the actual message retrieval and analysis logic
# Send messages to create activity patterns
await client.send_message("ENGLAND", "FRANCE", "Hello France!")
await client.send_message("ENGLAND", "FRANCE", "Another message from England")
await client.send_message("GERMANY", "FRANCE", "Message from Germany")
await client.send_message("ITALY", "GLOBAL", "Global message from Italy")
await asyncio.sleep(0.1)
await client.synchronize()
# Now analyze targeting
targets = await analyze_recent_messages_for_targeting(client=client, power_name="FRANCE", max_messages=20)
# ENGLAND should be prioritized (sent 2 direct messages to FRANCE)
# GERMANY should be second (sent 1 direct message to FRANCE)
# Others should follow
assert isinstance(targets, list)
assert "FRANCE" not in targets
# England should be first in priority (most direct messages to FRANCE)
if "ENGLAND" in targets:
england_index = targets.index("ENGLAND")
if "GERMANY" in targets:
germany_index = targets.index("GERMANY")
assert england_index < germany_index, "England should be prioritized over Germany"
async def test_analyze_messages_error_handling(self, client_with_messages):
"""Test error handling in message analysis."""
client = client_with_messages
# Test with invalid power name
targets = await analyze_recent_messages_for_targeting(client=client, power_name="INVALID_POWER", max_messages=20)
# Should still return a valid list (fallback behavior)
assert isinstance(targets, list)
class TestNegotiationTiming:
"""Test negotiation delay and timing logic."""
def test_negotiation_delay_calculation(self):
"""Test that negotiation delays are calculated correctly."""
# Test first round gets extra time
first_round_delay = get_negotiation_delay(round_number=1, total_rounds=3)
base_delay = get_negotiation_delay(round_number=2, total_rounds=3)
final_round_delay = get_negotiation_delay(round_number=3, total_rounds=3)
assert first_round_delay > base_delay, "First round should have longer delay"
assert final_round_delay < base_delay, "Final round should have shorter delay"
# All delays should be positive
assert first_round_delay > 0
assert base_delay > 0
assert final_round_delay > 0
def test_negotiation_delay_edge_cases(self):
"""Test delay calculation with edge cases."""
# Single round
single_delay = get_negotiation_delay(round_number=1, total_rounds=1)
assert single_delay > 0
# Many rounds
many_rounds_delay = get_negotiation_delay(round_number=5, total_rounds=10)
assert many_rounds_delay > 0
class TestNegotiationParticipation:
"""Test logic for determining negotiation participation."""
@pytest.fixture
async def client_and_mocked_agent(self, fake_server):
"""Client with mocked agent for participation testing."""
client = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
try:
await client.connect_and_authenticate("participation_user", "password")
await client.create_game(
map_name="standard",
rules=["IGNORE_ERRORS"],
power_name="RUSSIA",
n_controls=1,
)
await client.synchronize()
# Create mock agent
mock_agent = MagicMock(spec=DiplomacyAgent)
mock_agent.power_name = "RUSSIA"
yield {"client": client, "agent": mock_agent}
finally:
try:
await client.close()
except:
pass
async def test_should_participate_eliminated_power(self, client_and_mocked_agent):
"""Test that eliminated powers don't participate in negotiations."""
client = client_and_mocked_agent["client"]
agent = client_and_mocked_agent["agent"]
# Mock the power as eliminated
with patch.object(client, "get_power") as mock_get_power:
mock_power = MagicMock()
mock_power.is_eliminated.return_value = True
mock_get_power.return_value = mock_power
should_participate = await should_participate_in_negotiations(client, agent)
assert should_participate is False
async def test_should_participate_no_orders(self, client_and_mocked_agent):
"""Test that powers with no orderable locations don't negotiate."""
client = client_and_mocked_agent["client"]
agent = client_and_mocked_agent["agent"]
# Mock the power as not eliminated but with no orderable locations
with patch.object(client, "get_power") as mock_get_power, patch("websocket_negotiations.gather_possible_orders") as mock_orders:
mock_power = MagicMock()
mock_power.is_eliminated.return_value = False
mock_get_power.return_value = mock_power
mock_orders.return_value = [] # No possible orders
should_participate = await should_participate_in_negotiations(client, agent)
assert should_participate is False
async def test_should_participate_non_movement_phase(self, client_and_mocked_agent):
"""Test that powers don't negotiate in non-movement phases."""
client = client_and_mocked_agent["client"]
agent = client_and_mocked_agent["agent"]
# Mock retreat phase
with (
patch.object(client, "get_power") as mock_get_power,
patch.object(client, "get_current_short_phase") as mock_phase,
patch("websocket_negotiations.gather_possible_orders") as mock_orders,
):
mock_power = MagicMock()
mock_power.is_eliminated.return_value = False
mock_get_power.return_value = mock_power
mock_orders.return_value = ["A Moscow - Hold"] # Has orders
mock_phase.return_value = "S1901R" # Retreat phase
should_participate = await should_participate_in_negotiations(client, agent)
assert should_participate is False
async def test_should_participate_movement_phase_with_orders(self, client_and_mocked_agent):
"""Test that active powers with orders participate in movement phases."""
client = client_and_mocked_agent["client"]
agent = client_and_mocked_agent["agent"]
# Mock movement phase with orderable locations
with (
patch.object(client, "get_power") as mock_get_power,
patch.object(client, "get_current_short_phase") as mock_phase,
patch("websocket_negotiations.gather_possible_orders") as mock_orders,
):
mock_power = MagicMock()
mock_power.is_eliminated.return_value = False
mock_get_power.return_value = mock_power
mock_orders.return_value = ["A Moscow - Hold", "F Sevastopol - Black Sea"]
mock_phase.return_value = "S1901M" # Movement phase
should_participate = await should_participate_in_negotiations(client, agent)
assert should_participate is True
class TestStrategicNegotiationRound:
"""Test the complete strategic negotiation round functionality."""
@pytest.fixture
async def negotiation_setup(self, fake_server):
"""Setup for testing complete negotiation rounds."""
client = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
try:
await client.connect_and_authenticate("negotiation_user", "password")
await client.create_game(
map_name="standard",
rules=["IGNORE_ERRORS", "POWER_CHOICE"],
power_name="TURKEY",
n_controls=1,
)
await client.synchronize()
# Create mock agent with necessary attributes
mock_agent = MagicMock(spec=DiplomacyAgent)
mock_agent.power_name = "TURKEY"
mock_agent.goals = "Expand into the Mediterranean"
mock_agent.relationships = {}
mock_agent.format_private_diary_for_prompt.return_value = "Test diary entry"
# Mock AI client for message generation
mock_ai_client = AsyncMock()
mock_agent.client = mock_ai_client
# Create game history
game_history = GameHistory()
# Mock error stats
error_stats = {"test_model": {"conversation_errors": 0, "order_decoding_errors": 0}}
yield {"client": client, "agent": mock_agent, "game_history": game_history, "error_stats": error_stats}
finally:
try:
await client.close()
except:
pass
async def test_successful_negotiation_round(self, negotiation_setup):
"""Test a successful negotiation round with message generation."""
setup = negotiation_setup
client = setup["client"]
agent = setup["agent"]
game_history = setup["game_history"]
error_stats = setup["error_stats"]
# Mock AI response for message generation
mock_messages = [
{"content": "Greetings! Turkey seeks peaceful relations.", "message_type": "global"},
{"content": "Russia, shall we coordinate our efforts?", "message_type": "private", "recipient": "RUSSIA"},
]
agent.client.get_conversation_reply.return_value = mock_messages
# Mock power and orders
with (
patch.object(client, "get_power") as mock_get_power,
patch("websocket_negotiations.gather_possible_orders") as mock_orders,
patch.object(client, "send_message") as mock_send,
):
mock_power = MagicMock()
mock_power.is_eliminated.return_value = False
mock_get_power.return_value = mock_power
mock_orders.return_value = ["A Constantinople - Hold"]
# Run negotiation round
success = await conduct_strategic_negotiation_round(
client=client,
agent=agent,
game_history=game_history,
model_error_stats=error_stats,
log_file_path="/tmp/test_log.txt",
round_number=1,
max_rounds=3,
)
assert success is True
# Should have called send_message for each generated message
assert mock_send.call_count == len(mock_messages)
async def test_negotiation_round_no_messages(self, negotiation_setup):
"""Test negotiation round when AI generates no messages."""
setup = negotiation_setup
client = setup["client"]
agent = setup["agent"]
game_history = setup["game_history"]
error_stats = setup["error_stats"]
# Mock AI response with no messages
agent.client.get_conversation_reply.return_value = []
# Mock power and orders
with patch.object(client, "get_power") as mock_get_power, patch("websocket_negotiations.gather_possible_orders") as mock_orders:
mock_power = MagicMock()
mock_power.is_eliminated.return_value = False
mock_get_power.return_value = mock_power
mock_orders.return_value = ["A Constantinople - Hold"]
# Run negotiation round
success = await conduct_strategic_negotiation_round(
client=client,
agent=agent,
game_history=game_history,
model_error_stats=error_stats,
log_file_path="/tmp/test_log.txt",
round_number=1,
max_rounds=3,
)
assert success is False
async def test_negotiation_round_ai_error(self, negotiation_setup):
"""Test negotiation round when AI client raises an exception."""
setup = negotiation_setup
client = setup["client"]
agent = setup["agent"]
game_history = setup["game_history"]
error_stats = setup["error_stats"]
# Mock AI client to raise exception
agent.client.get_conversation_reply.side_effect = Exception("AI service unavailable")
agent.client.model_name = "test_model"
# Mock power and orders
with patch.object(client, "get_power") as mock_get_power, patch("websocket_negotiations.gather_possible_orders") as mock_orders:
mock_power = MagicMock()
mock_power.is_eliminated.return_value = False
mock_get_power.return_value = mock_power
mock_orders.return_value = ["A Constantinople - Hold"]
# Run negotiation round
success = await conduct_strategic_negotiation_round(
client=client,
agent=agent,
game_history=game_history,
model_error_stats=error_stats,
log_file_path="/tmp/test_log.txt",
round_number=1,
max_rounds=3,
)
assert success is False
# Error should be tracked in statistics
assert error_stats["test_model"]["conversation_errors"] == 1
async def test_negotiation_round_with_targeting(self, negotiation_setup):
"""Test that negotiation round uses strategic targeting."""
setup = negotiation_setup
client = setup["client"]
agent = setup["agent"]
game_history = setup["game_history"]
error_stats = setup["error_stats"]
# Mock message targeting analysis
with patch("websocket_negotiations.analyze_recent_messages_for_targeting") as mock_targeting:
mock_targeting.return_value = ["RUSSIA", "AUSTRIA", "ITALY"]
# Mock AI response
mock_messages = [{"content": "Test message", "message_type": "global"}]
agent.client.get_conversation_reply.return_value = mock_messages
# Mock power and orders
with patch.object(client, "get_power") as mock_get_power, patch("websocket_negotiations.gather_possible_orders") as mock_orders:
mock_power = MagicMock()
mock_power.is_eliminated.return_value = False
mock_get_power.return_value = mock_power
mock_orders.return_value = ["A Constantinople - Hold"]
# Run negotiation round 2 (should use targeting)
await conduct_strategic_negotiation_round(
client=client,
agent=agent,
game_history=game_history,
model_error_stats=error_stats,
log_file_path="/tmp/test_log.txt",
round_number=2, # Round > 1 should use targeting
max_rounds=3,
)
# Should have called the targeting analysis
mock_targeting.assert_called_once_with(client, "TURKEY")

View file

@ -19,16 +19,14 @@ from models import (
serialize_message,
parse_message,
)
from typed_websocket_client import TypedWebSocketDiplomacyClient
from websocket_diplomacy_client import WebSocketDiplomacyClient
class TestAuthentication:
"""Test authentication flow and message handling."""
@pytest.mark.websocket
async def test_successful_authentication(
self, client: TypedWebSocketDiplomacyClient, helpers
):
async def test_successful_authentication(self, client: WebSocketDiplomacyClient, helpers):
"""Test successful authentication with valid credentials."""
await client.connect_and_authenticate("test_user", "test_password")
helpers.assert_valid_token(client.token)
@ -44,16 +42,12 @@ class TestAuthentication:
assert client.token is None
@pytest.mark.websocket
async def test_authentication_message_format(
self, credentials: Dict[str, str], client: TypedWebSocketDiplomacyClient
):
async def test_authentication_message_format(self, credentials: Dict[str, str], client: WebSocketDiplomacyClient):
"""Test that authentication messages are properly formatted."""
await client.connect_and_authenticate(**credentials)
# Create a sign-in request manually to test message format
request = SignInRequest(
request_id=str(uuid.uuid4()), username="test_user", password="test_password"
)
request = SignInRequest(request_id=str(uuid.uuid4()), username="test_user", password="test_password")
# Verify the request serializes correctly
message_data = serialize_message(request)
@ -74,9 +68,7 @@ class TestGameOperations:
@pytest.mark.websocket
async def test_create_game(self, authenticated_client, helpers):
"""Test creating a new game."""
game_data = await authenticated_client.create_game(
power_name="FRANCE", n_controls=1
)
game_data = await authenticated_client.create_game(power_name="FRANCE", n_controls=1)
helpers.assert_valid_game_data(game_data)
assert authenticated_client.game_id is not None
@ -98,9 +90,7 @@ class TestGameOperations:
assert authenticated_client.game_id in game_ids
@pytest.mark.websocket
async def test_join_existing_game(
self, client: TypedWebSocketDiplomacyClient, helpers
):
async def test_join_existing_game(self, client: WebSocketDiplomacyClient, helpers):
"""Test joining an existing game."""
# Connect and authenticate a first client
await client.connect()
@ -130,9 +120,7 @@ class TestGameOperations:
async def test_join_nonexistent_game(self, authenticated_client):
"""Test joining a game that doesn't exist."""
with pytest.raises(ValueError, match="Game .* not found"):
await authenticated_client.join_game(
"NONEXISTENT_GAME", power_name="FRANCE"
)
await authenticated_client.join_game("NONEXISTENT_GAME", power_name="FRANCE")
class TestGamePlay:
@ -246,9 +234,7 @@ class TestMessageValidation:
"""Test that messages can be serialized and parsed correctly."""
# Test various message types
messages = [
SignInRequest(
request_id=str(uuid.uuid4()), username="test", password="pass"
),
SignInRequest(request_id=str(uuid.uuid4()), username="test", password="pass"),
CreateGameRequest(
request_id=str(uuid.uuid4()),
token="test_token",
@ -289,7 +275,7 @@ class TestConcurrentOperations:
try:
# Create multiple clients
for i in range(3):
client = TypedWebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
client = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
await client.connect()
await client.authenticate("test_user", "test_password")
clients.append(client)
@ -300,9 +286,7 @@ class TestConcurrentOperations:
# Each can create games independently
for i, client in enumerate(clients):
game_data = await client.create_game(
power_name="FRANCE" if i == 0 else None, n_controls=1
)
game_data = await client.create_game(power_name="FRANCE" if i == 0 else None, n_controls=1)
assert game_data["game_id"] is not None
finally:

View file

@ -36,3 +36,4 @@ dependencies = [
[tool.ruff]
line-length = 150
exclude = ["diplomacy", "docs"]
line-length = 150