AI_Diplomacy/bot_client/tests/test_message_functionality_unit.py
Tyler Marques 79a31d615d
WIP: Continual work on the bot client. They successfully play the game, but don't negotiate
Signed-off-by: Tyler Marques <me@tylermarques.com>
2025-07-29 16:00:42 -07:00

356 lines
15 KiB
Python

"""
Unit tests for the new inter-power messaging functionality.
These tests focus on testing the specific functionality we added:
1. Enhanced WebSocketDiplomacyClient methods
2. Message tracking and statistics in SingleBotPlayer
3. Negotiation targeting logic
These are more focused unit tests that don't require full server integration.
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from typing import Dict, List
from websocket_diplomacy_client import WebSocketDiplomacyClient
from single_bot_player import SingleBotPlayer
from websocket_negotiations import (
analyze_recent_messages_for_targeting,
should_participate_in_negotiations,
get_negotiation_delay,
)
from diplomacy.engine.message import Message, GLOBAL
class TestWebSocketClientEnhancements:
"""Test the enhanced methods we added to WebSocketDiplomacyClient."""
@pytest.fixture
def mock_client(self):
"""Create a mocked WebSocketDiplomacyClient for testing."""
client = WebSocketDiplomacyClient()
# Mock the game and related objects
client.game = MagicMock()
mock_powers = {
"FRANCE": MagicMock(is_eliminated=lambda: False),
"ENGLAND": MagicMock(is_eliminated=lambda: False),
"GERMANY": MagicMock(is_eliminated=lambda: False),
"ITALY": MagicMock(is_eliminated=lambda: True), # Eliminated
}
client.game.powers = mock_powers
# Mock message history
mock_messages = [
Message(sender="ENGLAND", recipient="FRANCE", message="Hello France!", phase="S1901M"),
Message(sender="GERMANY", recipient="FRANCE", message="Greetings!", phase="S1901M"),
Message(sender="FRANCE", recipient="GLOBAL", message="Hello everyone!", phase="S1901M"),
]
client.game.messages = {i: msg for i, msg in enumerate(mock_messages)}
return client
async def test_send_broadcast_message(self, mock_client):
"""Test the new send_broadcast_message method."""
# Mock the send_message method and powers property
mock_client.send_message = AsyncMock()
# Mock the powers property to return our mock powers
with patch.object(type(mock_client), "powers", new_callable=lambda: property(lambda self: self.game.powers)):
# Send broadcast message
await mock_client.send_broadcast_message(sender="FRANCE", message="Hello to all active powers!")
# Should have called send_message for each active power (excluding sender and eliminated)
expected_calls = 2 # ENGLAND and GERMANY (not ITALY because eliminated, not FRANCE because sender)
assert mock_client.send_message.call_count == expected_calls
# Check that messages were sent to the right powers
call_args_list = mock_client.send_message.call_args_list
recipients = [call[1]["recipient"] for call in call_args_list] # Get recipient from kwargs
assert "ENGLAND" in recipients
assert "GERMANY" in recipients
assert "FRANCE" not in recipients # Shouldn't send to self
assert "ITALY" not in recipients # Shouldn't send to eliminated power
async def test_get_recent_messages_filtering(self, mock_client):
"""Test the new get_recent_messages method with filtering."""
# Mock current phase
mock_client.get_current_short_phase = MagicMock(return_value="S1901M")
# Mock the messages property to return our mock messages
with patch.object(type(mock_client), "messages", new_callable=lambda: property(lambda self: self.messages)):
# Test getting recent messages
recent_messages = await mock_client.get_recent_messages(limit=5)
# Should return the mocked messages
assert len(recent_messages) <= 5
assert len(recent_messages) == 3 # We have 3 mock messages
# Test phase filtering
phase_messages = await mock_client.get_recent_messages(phase="S1901M", limit=10)
assert all(msg.phase == "S1901M" for msg in phase_messages)
async def test_get_recent_messages_limit(self, mock_client):
"""Test that the limit parameter works correctly."""
# Mock the messages property to return our mock messages
with patch.object(type(mock_client), "messages", new_callable=lambda: property(lambda self: self.messages)):
# Test with limit smaller than available messages
limited_messages = await mock_client.get_recent_messages(limit=2)
assert len(limited_messages) <= 2
# Test with limit larger than available messages
all_messages = await mock_client.get_recent_messages(limit=100)
assert len(all_messages) == 3 # Should not exceed available messages
class TestSingleBotPlayerMessageTracking:
"""Test the message tracking functionality in SingleBotPlayer."""
@pytest.fixture
def mock_bot_player(self):
"""Create a mock SingleBotPlayer for testing."""
bot = SingleBotPlayer(username="test_bot", password="test_pass", power_name="FRANCE", model_name="test_model")
# Mock dependencies
bot.client = MagicMock()
bot.agent = MagicMock()
bot.game_history = MagicMock()
# Set up powers
bot.client.powers = {
"FRANCE": MagicMock(is_eliminated=lambda: False),
"ENGLAND": MagicMock(is_eliminated=lambda: False),
"GERMANY": MagicMock(is_eliminated=lambda: False),
"ITALY": MagicMock(is_eliminated=lambda: False),
}
return bot
def test_message_counting_initialization(self, mock_bot_player):
"""Test that message tracking starts in clean state."""
bot = mock_bot_player
assert isinstance(bot.message_counts, dict)
assert isinstance(bot.response_counts, dict)
assert isinstance(bot.priority_contacts, list)
assert len(bot.message_counts) == 0
assert len(bot.response_counts) == 0
assert len(bot.priority_contacts) == 0
def test_priority_contact_updates(self, mock_bot_player):
"""Test that priority contacts are updated correctly."""
bot = mock_bot_player
# Set up message counts
bot.message_counts = {
"ENGLAND": 5,
"GERMANY": 3,
"ITALY": 2,
"AUSTRIA": 1,
"RUSSIA": 4,
}
# Update priority contacts
bot._update_priority_contacts()
# Should have top 4 contacts in order of activity
assert len(bot.priority_contacts) == 4
assert bot.priority_contacts[0] == "ENGLAND" # Highest count (5)
assert bot.priority_contacts[1] == "RUSSIA" # Second highest (4)
assert bot.priority_contacts[2] == "GERMANY" # Third highest (3)
assert bot.priority_contacts[3] == "ITALY" # Fourth highest (2)
def test_message_statistics_generation(self, mock_bot_player):
"""Test generation of message statistics."""
bot = mock_bot_player
bot.current_phase = "S1901M"
# Set up data
bot.message_counts = {"ENGLAND": 4, "GERMANY": 2}
bot.response_counts = {"ENGLAND": 3, "GERMANY": 1}
bot.priority_contacts = ["ENGLAND", "GERMANY"]
# Generate statistics
stats = bot.get_message_statistics()
# Check basic structure
assert stats["power_name"] == "FRANCE"
assert stats["current_phase"] == "S1901M"
assert stats["total_messages_received"] == 6 # 4 + 2
assert stats["total_responses_sent"] == 4 # 3 + 1
# Check response rates
assert stats["response_rate_by_power"]["ENGLAND"] == 0.75 # 3/4
assert stats["response_rate_by_power"]["GERMANY"] == 0.5 # 1/2
def test_response_decision_logic(self, mock_bot_player):
"""Test the enhanced response decision logic."""
bot = mock_bot_player
bot.priority_contacts = ["ENGLAND"]
# Create test messages
priority_message = MagicMock()
priority_message.sender = "ENGLAND"
priority_message.recipient = "FRANCE"
priority_message.message = "Hello France!"
non_priority_message = MagicMock()
non_priority_message.sender = "GERMANY"
non_priority_message.recipient = "FRANCE"
non_priority_message.message = "Hello France!"
# Test the decision logic (extracted from _consider_message_response)
def should_respond(message):
message_lower = message.message.lower()
strategic_keywords = ["alliance", "deal", "propose", "agreement"]
return any(
[
"?" in message.message,
any(word in message_lower for word in ["hello", "hi", "greetings"]),
any(keyword in message_lower for keyword in strategic_keywords),
len(message.message.split()) > 15,
message.sender in bot.priority_contacts,
]
)
# Both should respond due to "hello", but priority logic is tested
assert should_respond(priority_message) is True
assert should_respond(non_priority_message) is True
# Test priority contact influence
assert priority_message.sender in bot.priority_contacts
assert non_priority_message.sender not in bot.priority_contacts
class TestNegotiationTargeting:
"""Test the strategic negotiation targeting logic."""
@pytest.fixture
def mock_client_with_messages(self):
"""Create a mock client with message history."""
client = MagicMock()
# Mock recent messages
mock_messages = [
Message(sender="ENGLAND", recipient="FRANCE", message="Direct to France", phase="S1901M"),
Message(sender="ENGLAND", recipient="FRANCE", message="Another to France", phase="S1901M"),
Message(sender="GERMANY", recipient="FRANCE", message="Message to France", phase="S1901M"),
Message(sender="ITALY", recipient="GLOBAL", message="Global message", phase="S1901M"),
]
client.get_recent_messages = AsyncMock(return_value=mock_messages)
client.powers = {
"FRANCE": MagicMock(is_eliminated=lambda: False),
"ENGLAND": MagicMock(is_eliminated=lambda: False),
"GERMANY": MagicMock(is_eliminated=lambda: False),
"ITALY": MagicMock(is_eliminated=lambda: False),
}
return client
async def test_analyze_recent_messages_for_targeting(self, mock_client_with_messages):
"""Test the message analysis for targeting."""
client = mock_client_with_messages
# Analyze targeting for FRANCE
targets = await analyze_recent_messages_for_targeting(client, "FRANCE", max_messages=20)
# Should return a list of powers
assert isinstance(targets, list)
assert "FRANCE" not in targets # Should not include self
# ENGLAND should be prioritized (sent 2 direct messages to FRANCE)
# GERMANY should be second (sent 1 direct message to FRANCE)
if len(targets) >= 2:
assert targets[0] == "ENGLAND" # Most direct messages
if "GERMANY" in targets:
germany_index = targets.index("GERMANY")
england_index = targets.index("ENGLAND")
assert england_index < germany_index # England should come before Germany
def test_negotiation_delay_calculation(self):
"""Test that negotiation delays are calculated correctly."""
# Test different round scenarios
first_delay = get_negotiation_delay(round_number=1, total_rounds=3)
middle_delay = get_negotiation_delay(round_number=2, total_rounds=3)
final_delay = get_negotiation_delay(round_number=3, total_rounds=3)
# First round should have longer delay
assert first_delay > middle_delay
# Final round should have shorter delay
assert final_delay < middle_delay
# All delays should be positive
assert all(delay > 0 for delay in [first_delay, middle_delay, final_delay])
async def test_should_participate_in_negotiations(self):
"""Test negotiation participation logic."""
# Mock client and agent
mock_client = MagicMock()
mock_agent = MagicMock()
mock_agent.power_name = "FRANCE"
# Test case: active power in movement phase
mock_power = MagicMock()
mock_power.is_eliminated.return_value = False
mock_client.get_power.return_value = mock_power
mock_client.get_current_short_phase.return_value = "S1901M"
with patch("websocket_negotiations.gather_possible_orders") as mock_orders:
mock_orders.return_value = ["A Paris - Hold"] # Has orders
result = await should_participate_in_negotiations(mock_client, mock_agent)
assert result is True
# Test case: eliminated power
mock_power.is_eliminated.return_value = True
result = await should_participate_in_negotiations(mock_client, mock_agent)
assert result is False
# Test case: non-movement phase
mock_power.is_eliminated.return_value = False
mock_client.get_current_short_phase.return_value = "S1901R" # Retreat phase
with patch("websocket_negotiations.gather_possible_orders") as mock_orders:
mock_orders.return_value = ["A Paris - Hold"]
result = await should_participate_in_negotiations(mock_client, mock_agent)
assert result is False
class TestIntegrationScenarios:
"""Test integration scenarios with mocked components."""
def test_message_persistence_across_phases(self):
"""Test that message counts persist across multiple game phases."""
bot = SingleBotPlayer(username="test_bot", password="test_pass", power_name="AUSTRIA", model_name="test_model")
# Mock dependencies
bot.client = MagicMock()
bot.agent = MagicMock()
bot.game_history = MagicMock()
# Simulate message accumulation over phases
phases = ["S1901M", "F1901M", "W1901A", "S1902M"]
for phase in phases:
bot.current_phase = phase
# Simulate receiving messages
if phase.endswith("M"): # Movement phases
for sender in ["FRANCE", "ENGLAND"]:
bot.message_counts[sender] = bot.message_counts.get(sender, 0) + 2
else:
# Fewer messages in other phases
bot.message_counts["FRANCE"] = bot.message_counts.get("FRANCE", 0) + 1
bot._update_priority_contacts()
# Check accumulated counts
assert bot.message_counts["FRANCE"] >= 5 # Should have accumulated messages
assert bot.message_counts["ENGLAND"] >= 4
assert len(bot.priority_contacts) > 0
assert bot.priority_contacts[0] == "FRANCE" # Should be top priority