mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-19 12:58:09 +00:00
356 lines
15 KiB
Python
356 lines
15 KiB
Python
"""
|
|
Unit tests for the new inter-power messaging functionality.
|
|
|
|
These tests focus on testing the specific functionality we added:
|
|
1. Enhanced WebSocketDiplomacyClient methods
|
|
2. Message tracking and statistics in SingleBotPlayer
|
|
3. Negotiation targeting logic
|
|
|
|
These are more focused unit tests that don't require full server integration.
|
|
"""
|
|
|
|
import pytest
|
|
from unittest.mock import MagicMock, AsyncMock, patch
|
|
from typing import Dict, List
|
|
|
|
from websocket_diplomacy_client import WebSocketDiplomacyClient
|
|
from single_bot_player import SingleBotPlayer
|
|
from websocket_negotiations import (
|
|
analyze_recent_messages_for_targeting,
|
|
should_participate_in_negotiations,
|
|
get_negotiation_delay,
|
|
)
|
|
from diplomacy.engine.message import Message, GLOBAL
|
|
|
|
|
|
class TestWebSocketClientEnhancements:
    """Exercise the messaging helpers added to WebSocketDiplomacyClient.

    Every test runs against the ``mock_client`` fixture, whose ``game``
    attribute is a ``MagicMock`` carrying canned powers and messages.

    NOTE(review): the coroutine tests carry no asyncio marker, so they
    presumably rely on project-level pytest configuration (for example
    pytest-asyncio in auto mode) -- confirm.
    """

    @pytest.fixture
    def mock_client(self):
        """Return a WebSocketDiplomacyClient whose ``game`` attribute is mocked.

        The mocked game exposes four powers (ITALY reports itself eliminated)
        and three S1901M messages keyed by their position in the list.
        """
        client = WebSocketDiplomacyClient()

        # Mock the game and related objects
        client.game = MagicMock()
        client.game.powers = {
            "FRANCE": MagicMock(is_eliminated=lambda: False),
            "ENGLAND": MagicMock(is_eliminated=lambda: False),
            "GERMANY": MagicMock(is_eliminated=lambda: False),
            "ITALY": MagicMock(is_eliminated=lambda: True),  # Eliminated
        }

        # Mock message history
        mock_messages = [
            Message(sender="ENGLAND", recipient="FRANCE", message="Hello France!", phase="S1901M"),
            Message(sender="GERMANY", recipient="FRANCE", message="Greetings!", phase="S1901M"),
            Message(sender="FRANCE", recipient="GLOBAL", message="Hello everyone!", phase="S1901M"),
        ]
        client.game.messages = dict(enumerate(mock_messages))

        return client

    async def test_send_broadcast_message(self, mock_client):
        """A broadcast must reach every active power except the sender."""
        mock_client.send_message = AsyncMock()

        # Patch ``powers`` on the class so it reads from the mocked game.
        with patch.object(type(mock_client), "powers", new_callable=lambda: property(lambda self: self.game.powers)):
            await mock_client.send_broadcast_message(sender="FRANCE", message="Hello to all active powers!")

        # ENGLAND and GERMANY only: ITALY is eliminated and FRANCE is the sender.
        expected_calls = 2
        assert mock_client.send_message.call_count == expected_calls

        # The recipient is read from the keyword arguments of each recorded call.
        recipients = [call.kwargs["recipient"] for call in mock_client.send_message.call_args_list]

        assert "ENGLAND" in recipients
        assert "GERMANY" in recipients
        assert "FRANCE" not in recipients  # Shouldn't send to self
        assert "ITALY" not in recipients  # Shouldn't send to eliminated power

    async def test_get_recent_messages_filtering(self, mock_client):
        """get_recent_messages must honour ``limit`` and the ``phase`` filter."""
        mock_client.get_current_short_phase = MagicMock(return_value="S1901M")

        # The replacement property must read from ``self.game``: reading
        # ``self.messages`` would re-enter the property itself and recurse
        # without end.
        with patch.object(type(mock_client), "messages", new_callable=lambda: property(lambda self: self.game.messages)):
            recent_messages = await mock_client.get_recent_messages(limit=5)

            assert len(recent_messages) <= 5
            assert len(recent_messages) == 3  # We have 3 mock messages

            # Test phase filtering
            phase_messages = await mock_client.get_recent_messages(phase="S1901M", limit=10)
            assert all(msg.phase == "S1901M" for msg in phase_messages)

    async def test_get_recent_messages_limit(self, mock_client):
        """``limit`` must cap the result and never invent extra messages."""
        # Same game-backed property as in the filtering test; a self-referencing
        # property here would recurse instead of returning the canned messages.
        with patch.object(type(mock_client), "messages", new_callable=lambda: property(lambda self: self.game.messages)):
            # Limit smaller than the number of available messages
            limited_messages = await mock_client.get_recent_messages(limit=2)
            assert len(limited_messages) <= 2

            # Limit larger than the number of available messages
            all_messages = await mock_client.get_recent_messages(limit=100)
            assert len(all_messages) == 3  # Should not exceed available messages
|
|
|
|
|
|
class TestSingleBotPlayerMessageTracking:
    """Check the per-power message bookkeeping kept by SingleBotPlayer."""

    @pytest.fixture
    def mock_bot_player(self):
        """Build a SingleBotPlayer for FRANCE with client, agent and history mocked."""
        bot = SingleBotPlayer(username="test_bot", password="test_pass", power_name="FRANCE", model_name="test_model")

        # Mock dependencies
        bot.client = MagicMock()
        bot.agent = MagicMock()
        bot.game_history = MagicMock()

        # Four powers, none of them eliminated
        bot.client.powers = {
            "FRANCE": MagicMock(is_eliminated=lambda: False),
            "ENGLAND": MagicMock(is_eliminated=lambda: False),
            "GERMANY": MagicMock(is_eliminated=lambda: False),
            "ITALY": MagicMock(is_eliminated=lambda: False),
        }

        return bot

    def test_message_counting_initialization(self, mock_bot_player):
        """A fresh bot must start with empty tracking containers."""
        bot = mock_bot_player

        assert isinstance(bot.message_counts, dict)
        assert isinstance(bot.response_counts, dict)
        assert isinstance(bot.priority_contacts, list)

        assert len(bot.message_counts) == 0
        assert len(bot.response_counts) == 0
        assert len(bot.priority_contacts) == 0

    def test_priority_contact_updates(self, mock_bot_player):
        """Priority contacts must be the four busiest senders, busiest first."""
        bot = mock_bot_player

        # Five senders with distinct counts, so the expected order is unambiguous.
        bot.message_counts = {
            "ENGLAND": 5,
            "GERMANY": 3,
            "ITALY": 2,
            "AUSTRIA": 1,
            "RUSSIA": 4,
        }

        bot._update_priority_contacts()

        # Should have top 4 contacts in order of activity
        assert len(bot.priority_contacts) == 4
        assert bot.priority_contacts[0] == "ENGLAND"  # Highest count (5)
        assert bot.priority_contacts[1] == "RUSSIA"  # Second highest (4)
        assert bot.priority_contacts[2] == "GERMANY"  # Third highest (3)
        assert bot.priority_contacts[3] == "ITALY"  # Fourth highest (2)

    def test_message_statistics_generation(self, mock_bot_player):
        """get_message_statistics must report totals and per-power response rates."""
        bot = mock_bot_player
        bot.current_phase = "S1901M"

        # Set up data
        bot.message_counts = {"ENGLAND": 4, "GERMANY": 2}
        bot.response_counts = {"ENGLAND": 3, "GERMANY": 1}
        bot.priority_contacts = ["ENGLAND", "GERMANY"]

        stats = bot.get_message_statistics()

        # Check basic structure
        assert stats["power_name"] == "FRANCE"
        assert stats["current_phase"] == "S1901M"
        assert stats["total_messages_received"] == 6  # 4 + 2
        assert stats["total_responses_sent"] == 4  # 3 + 1

        # Check response rates
        assert stats["response_rate_by_power"]["ENGLAND"] == 0.75  # 3/4
        assert stats["response_rate_by_power"]["GERMANY"] == 0.5  # 1/2

    def test_response_decision_logic(self, mock_bot_player):
        """Check the response heuristic, including the priority-contact rule alone.

        The heuristic is re-stated locally (the original comment says it was
        extracted from ``_consider_message_response``), so this test pins the
        rule set rather than calling the bot method itself.
        """
        bot = mock_bot_player
        bot.priority_contacts = ["ENGLAND"]

        def make_message(sender, text):
            """Return a stand-in message addressed to FRANCE."""
            msg = MagicMock()
            msg.sender = sender
            msg.recipient = "FRANCE"
            msg.message = text
            return msg

        priority_message = make_message("ENGLAND", "Hello France!")
        non_priority_message = make_message("GERMANY", "Hello France!")

        # Test the decision logic (extracted from _consider_message_response)
        def should_respond(message):
            message_lower = message.message.lower()
            strategic_keywords = ["alliance", "deal", "propose", "agreement"]

            return any(
                [
                    "?" in message.message,
                    any(word in message_lower for word in ["hello", "hi", "greetings"]),
                    any(keyword in message_lower for keyword in strategic_keywords),
                    len(message.message.split()) > 15,
                    message.sender in bot.priority_contacts,
                ]
            )

        # Both greetings trigger a response through the "hello" rule.
        assert should_respond(priority_message) is True
        assert should_respond(non_priority_message) is True

        # Test priority contact influence
        assert priority_message.sender in bot.priority_contacts
        assert non_priority_message.sender not in bot.priority_contacts

        # "Noted." matches no question, greeting, keyword or length rule, so
        # only the sender's priority status decides the outcome here.
        assert should_respond(make_message("ENGLAND", "Noted.")) is True
        assert should_respond(make_message("GERMANY", "Noted.")) is False
|
|
|
|
|
|
class TestNegotiationTargeting:
    """Cover the helpers that decide whom to negotiate with, when, and whether to."""

    @pytest.fixture
    def mock_client_with_messages(self):
        """Return a MagicMock client preloaded with four S1901M messages.

        ENGLAND wrote to FRANCE twice, GERMANY once, and ITALY posted one
        message to GLOBAL. All four powers report themselves as active.
        """
        history = [
            Message(sender=sender, recipient=recipient, message=text, phase="S1901M")
            for sender, recipient, text in (
                ("ENGLAND", "FRANCE", "Direct to France"),
                ("ENGLAND", "FRANCE", "Another to France"),
                ("GERMANY", "FRANCE", "Message to France"),
                ("ITALY", "GLOBAL", "Global message"),
            )
        ]

        stub = MagicMock()
        stub.get_recent_messages = AsyncMock(return_value=history)
        stub.powers = {
            "FRANCE": MagicMock(is_eliminated=lambda: False),
            "ENGLAND": MagicMock(is_eliminated=lambda: False),
            "GERMANY": MagicMock(is_eliminated=lambda: False),
            "ITALY": MagicMock(is_eliminated=lambda: False),
        }
        return stub

    async def test_analyze_recent_messages_for_targeting(self, mock_client_with_messages):
        """Targets for FRANCE must exclude FRANCE and rank ENGLAND ahead of GERMANY."""
        ranked = await analyze_recent_messages_for_targeting(mock_client_with_messages, "FRANCE", max_messages=20)

        assert isinstance(ranked, list)
        assert "FRANCE" not in ranked  # Should not include self

        # The ordering checks only apply once at least two targets came back.
        if len(ranked) < 2:
            return

        # ENGLAND sent two direct messages to FRANCE, GERMANY only one.
        assert ranked[0] == "ENGLAND"
        if "GERMANY" in ranked:
            position_of_germany = ranked.index("GERMANY")
            position_of_england = ranked.index("ENGLAND")
            assert position_of_england < position_of_germany

    def test_negotiation_delay_calculation(self):
        """Delays must shrink from the first round to the last and stay positive."""
        opening, midway, closing = (
            get_negotiation_delay(round_number=1, total_rounds=3),
            get_negotiation_delay(round_number=2, total_rounds=3),
            get_negotiation_delay(round_number=3, total_rounds=3),
        )

        # The first round waits longest, the final round shortest.
        assert opening > midway
        assert closing < midway
        assert all(value > 0 for value in (opening, midway, closing))

    async def test_should_participate_in_negotiations(self):
        """Only an active power in a movement phase should take part."""
        orders_target = "websocket_negotiations.gather_possible_orders"
        available_orders = ["A Paris - Hold"]

        agent = MagicMock()
        agent.power_name = "FRANCE"

        power = MagicMock()
        power.is_eliminated.return_value = False

        client = MagicMock()
        client.get_power.return_value = power
        client.get_current_short_phase.return_value = "S1901M"

        # Active power, movement phase, orders available -> participates.
        with patch(orders_target, return_value=available_orders):
            outcome = await should_participate_in_negotiations(client, agent)
            assert outcome is True

        # Eliminated power -> sits out.
        power.is_eliminated.return_value = True
        outcome = await should_participate_in_negotiations(client, agent)
        assert outcome is False

        # Active again, but in a retreat phase -> sits out.
        power.is_eliminated.return_value = False
        client.get_current_short_phase.return_value = "S1901R"

        with patch(orders_target, return_value=available_orders):
            outcome = await should_participate_in_negotiations(client, agent)
            assert outcome is False
|
|
|
|
|
|
class TestIntegrationScenarios:
    """Multi-step scenarios run against a SingleBotPlayer with mocked collaborators."""

    def test_message_persistence_across_phases(self):
        """Message counts must keep accumulating as the game moves through phases."""
        bot = SingleBotPlayer(username="test_bot", password="test_pass", power_name="AUSTRIA", model_name="test_model")

        # Swap out collaborators for mocks.
        bot.client = MagicMock()
        bot.agent = MagicMock()
        bot.game_history = MagicMock()

        for phase in ("S1901M", "F1901M", "W1901A", "S1902M"):
            bot.current_phase = phase

            # Movement phases bring two messages from each of FRANCE and
            # ENGLAND; any other phase brings a single message from FRANCE.
            if phase.endswith("M"):
                senders, increment = ("FRANCE", "ENGLAND"), 2
            else:
                senders, increment = ("FRANCE",), 1

            for sender in senders:
                bot.message_counts[sender] = bot.message_counts.get(sender, 0) + increment

            bot._update_priority_contacts()

        # Totals carry over from phase to phase rather than resetting.
        assert bot.message_counts["FRANCE"] >= 5
        assert bot.message_counts["ENGLAND"] >= 4
        assert len(bot.priority_contacts) > 0
        assert bot.priority_contacts[0] == "FRANCE"  # Should be top priority
|