mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-30 17:40:47 +00:00
WIP: Continual work on the bot client. They successfully play the game, but don't negotiate
Signed-off-by: Tyler Marques <me@tylermarques.com>
This commit is contained in:
parent
c09d5bfe7e
commit
79a31d615d
13 changed files with 2483 additions and 263 deletions
403
bot_client/tests/test_message_persistence.py
Normal file
403
bot_client/tests/test_message_persistence.py
Normal file
|
|
@ -0,0 +1,403 @@
|
|||
"""
|
||||
Integration tests for message persistence and tracking functionality.
|
||||
|
||||
These tests verify that the SingleBotPlayer can properly:
|
||||
1. Track message counts and response patterns
|
||||
2. Update priority contacts based on messaging activity
|
||||
3. Generate accurate message statistics
|
||||
4. Log statistics during phase transitions
|
||||
5. Maintain message persistence across game phases
|
||||
|
||||
Tests focus on the tracking and statistics features rather than
|
||||
the actual AI message generation.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from typing import Dict
|
||||
|
||||
from single_bot_player import SingleBotPlayer
|
||||
from websocket_diplomacy_client import WebSocketDiplomacyClient
|
||||
from diplomacy.engine.message import Message
|
||||
from ai_diplomacy.agent import DiplomacyAgent
|
||||
from ai_diplomacy.game_history import GameHistory
|
||||
|
||||
|
||||
class TestMessageCounting:
    """Test message counting and tracking functionality.

    The tests here exercise the bot's bookkeeping attributes
    (``message_counts``, ``response_counts``, ``priority_contacts``) and
    ``_update_priority_contacts`` directly; no AI message generation runs.
    """

    @pytest.fixture
    def mock_bot_player(self):
        """Create a SingleBotPlayer with mocked dependencies for testing."""
        bot = SingleBotPlayer(
            username="test_bot",
            password="test_pass",
            power_name="ENGLAND",
            model_name="test_model",
            game_id="test_game",
        )

        # Swap the client and the agent for spec'd mocks and start from an
        # empty game history so every test sees a clean bot.
        bot.client = MagicMock(spec=WebSocketDiplomacyClient)
        bot.agent = MagicMock(spec=DiplomacyAgent)
        bot.game_history = GameHistory()

        return bot

    def test_initial_message_tracking_state(self, mock_bot_player):
        """Test that message tracking starts in clean state."""
        bot = mock_bot_player

        assert isinstance(bot.message_counts, dict)
        assert isinstance(bot.response_counts, dict)
        assert isinstance(bot.priority_contacts, list)

        assert len(bot.message_counts) == 0
        assert len(bot.response_counts) == 0
        assert len(bot.priority_contacts) == 0

    def test_message_count_tracking(self, mock_bot_player):
        """Test that incoming messages are counted correctly."""
        bot = mock_bot_player

        # Simulate receiving messages from different powers
        test_messages = [
            Message(sender="FRANCE", recipient="ENGLAND", message="Hello England!", phase="S1901M"),
            Message(sender="FRANCE", recipient="ENGLAND", message="Another message", phase="S1901M"),
            Message(sender="GERMANY", recipient="ENGLAND", message="Greetings!", phase="S1901M"),
            # Not addressed to this bot: should not be counted
            Message(sender="ITALY", recipient="GLOBAL", message="Global message", phase="S1901M"),
            # Sent by this bot: should not be counted
            Message(sender="ENGLAND", recipient="FRANCE", message="Self message", phase="S1901M"),
        ]

        # Process each message through a copy of the handling logic
        # from _on_message_received
        for msg in test_messages:
            bot.game_history.add_message(
                phase_name=msg.phase,
                sender=msg.sender,
                recipient=msg.recipient,
                message_content=msg.message,
            )

            # Track message patterns (only for messages TO this bot)
            if msg.recipient == bot.power_name and msg.sender != bot.power_name:
                bot.message_counts[msg.sender] = bot.message_counts.get(msg.sender, 0) + 1
                bot._update_priority_contacts()

        # Check counts
        assert bot.message_counts["FRANCE"] == 2
        assert bot.message_counts["GERMANY"] == 1
        assert "ITALY" not in bot.message_counts  # Global message not counted
        assert "ENGLAND" not in bot.message_counts  # Self message not counted

    def test_response_count_tracking(self, mock_bot_player):
        """Test that outgoing responses are counted correctly."""
        bot = mock_bot_player

        # Simulate sending responses to different powers
        responses = [
            ("FRANCE", "Thanks for your message!"),
            ("FRANCE", "Another response to France"),
            ("GERMANY", "Hello Germany"),
        ]

        # Only the recipient matters for the counter; the text is ignored.
        for recipient, _message in responses:
            bot.response_counts[recipient] = bot.response_counts.get(recipient, 0) + 1

        # Check response counts
        assert bot.response_counts["FRANCE"] == 2
        assert bot.response_counts["GERMANY"] == 1

    def test_priority_contact_updates(self, mock_bot_player):
        """Test that priority contacts are updated based on message counts."""
        bot = mock_bot_player

        # Set up message counts with different activity levels
        bot.message_counts = {
            "FRANCE": 5,  # Highest
            "GERMANY": 3,  # Third highest
            "ITALY": 2,  # Fourth highest
            "AUSTRIA": 1,  # Lowest
            "RUSSIA": 4,  # Second highest
        }

        # Update priority contacts
        bot._update_priority_contacts()

        # Should have the top 4 contacts in descending order of activity
        assert len(bot.priority_contacts) == 4
        assert bot.priority_contacts == ["FRANCE", "RUSSIA", "GERMANY", "ITALY"]
        # AUSTRIA has the lowest count and must be cut from the top 4
        assert "AUSTRIA" not in bot.priority_contacts

    def test_priority_contacts_with_fewer_powers(self, mock_bot_player):
        """Test priority contacts when fewer than 4 powers are active."""
        bot = mock_bot_player

        # Set up message counts with only 2 powers
        bot.message_counts = {
            "FRANCE": 3,
            "GERMANY": 1,
        }

        bot._update_priority_contacts()

        # Should have only 2 contacts, most active first
        assert len(bot.priority_contacts) == 2
        assert bot.priority_contacts == ["FRANCE", "GERMANY"]
|
||||
|
||||
|
||||
class TestMessageStatistics:
    """Test message statistics generation and reporting."""

    @pytest.fixture
    def mock_bot_player(self):
        """Create a SingleBotPlayer with mocked dependencies for testing.

        Fixtures defined on a test class are only visible to tests of that
        class, so this class carries its own copy (it mirrors the fixture of
        the same name on ``TestMessageCounting``).
        """
        bot = SingleBotPlayer(
            username="test_bot",
            password="test_pass",
            power_name="ENGLAND",
            model_name="test_model",
            game_id="test_game",
        )

        # Mock the client and agent
        bot.client = MagicMock(spec=WebSocketDiplomacyClient)
        bot.agent = MagicMock(spec=DiplomacyAgent)
        bot.game_history = GameHistory()

        return bot

    @pytest.fixture
    def bot_with_message_data(self, mock_bot_player):
        """Bot player with pre-populated message tracking data."""
        bot = mock_bot_player

        # Set up realistic message and response data
        bot.message_counts = {
            "FRANCE": 8,
            "GERMANY": 5,
            "ITALY": 3,
            "RUSSIA": 2,
        }

        bot.response_counts = {
            "FRANCE": 6,  # 75% response rate
            "GERMANY": 2,  # 40% response rate
            "ITALY": 3,  # 100% response rate
            "RUSSIA": 0,  # 0% response rate
        }

        bot.priority_contacts = ["FRANCE", "GERMANY", "ITALY", "RUSSIA"]
        bot.current_phase = "S1901M"

        # Mock the client's powers for active power detection
        bot.client.powers = {
            "ENGLAND": MagicMock(is_eliminated=lambda: False),
            "FRANCE": MagicMock(is_eliminated=lambda: False),
            "GERMANY": MagicMock(is_eliminated=lambda: False),
            "ITALY": MagicMock(is_eliminated=lambda: False),
            "RUSSIA": MagicMock(is_eliminated=lambda: False),
            "AUSTRIA": MagicMock(is_eliminated=lambda: True),  # Eliminated
            "TURKEY": MagicMock(is_eliminated=lambda: False),
        }

        return bot

    def test_message_statistics_generation(self, bot_with_message_data):
        """Test that message statistics are generated correctly."""
        bot = bot_with_message_data

        stats = bot.get_message_statistics()

        # Check basic structure
        assert stats["power_name"] == "ENGLAND"
        assert stats["current_phase"] == "S1901M"
        assert isinstance(stats["message_counts_by_power"], dict)
        assert isinstance(stats["response_counts_by_power"], dict)
        assert isinstance(stats["response_rate_by_power"], dict)
        assert isinstance(stats["priority_contacts"], list)
        assert isinstance(stats["active_powers"], list)

        # Check calculated values
        assert stats["total_messages_received"] == 18  # 8+5+3+2
        assert stats["total_responses_sent"] == 11  # 6+2+3+0

        # Check response rates; approx() keeps the test independent of how
        # the ratio is computed or rounded.
        rates = stats["response_rate_by_power"]
        assert rates["FRANCE"] == pytest.approx(0.75)  # 6/8
        assert rates["GERMANY"] == pytest.approx(0.4)  # 2/5
        assert rates["ITALY"] == pytest.approx(1.0)  # 3/3
        assert rates["RUSSIA"] == pytest.approx(0.0)  # 0/2

        # Check active powers (should exclude eliminated AUSTRIA and self)
        active_powers = stats["active_powers"]
        assert "AUSTRIA" not in active_powers  # Eliminated
        assert "ENGLAND" not in active_powers  # Self
        assert "FRANCE" in active_powers
        assert "TURKEY" in active_powers

    def test_message_statistics_empty_data(self, mock_bot_player):
        """Test statistics generation with no message data."""
        bot = mock_bot_player
        bot.current_phase = "S1901M"

        # Minimal set of powers; no messages have been exchanged
        bot.client.powers = {
            "ENGLAND": MagicMock(is_eliminated=lambda: False),
            "FRANCE": MagicMock(is_eliminated=lambda: False),
        }

        stats = bot.get_message_statistics()

        assert stats["total_messages_received"] == 0
        assert stats["total_responses_sent"] == 0
        assert len(stats["message_counts_by_power"]) == 0
        assert len(stats["response_counts_by_power"]) == 0
        assert len(stats["response_rate_by_power"]) == 0
        assert len(stats["priority_contacts"]) == 0

    def test_message_statistics_logging(self, bot_with_message_data, caplog):
        """Test that message statistics are logged correctly."""
        bot = bot_with_message_data

        # Lower the capture threshold explicitly: with the default root level
        # (WARNING) records below it never reach caplog.
        # NOTE(review): presumably the statistics are logged at INFO - confirm.
        with caplog.at_level("INFO"):
            bot.log_message_statistics()

        # Check that appropriate log messages were generated
        log_output = caplog.text
        assert "Message Statistics for ENGLAND" in log_output
        assert "Total messages received: 18" in log_output
        assert "Total responses sent: 11" in log_output
        assert "Priority contacts:" in log_output

        # Check that individual power stats are logged
        assert "FRANCE:" in log_output
        assert "GERMANY:" in log_output
        assert "75%" in log_output  # France response rate
        assert "40%" in log_output  # Germany response rate
|
||||
|
||||
|
||||
class TestMessagePersistenceIntegration:
    """Test integration of message persistence with game flow."""

    @pytest.fixture
    def bot_integration_setup(self):
        """Build a SingleBotPlayer for ITALY with mocked client and agent.

        This is a plain synchronous fixture: the setup awaits nothing, and the
        tests consuming it are synchronous, so an ``async def`` fixture would
        hand them an un-awaited coroutine unless an asyncio plugin intervened.
        The client is a MagicMock, so nothing here needs a running server.
        """
        # Note: full SingleBotPlayer initialization depends on AI client setup;
        # for now these tests focus on the core tracking logic.
        bot = SingleBotPlayer(
            username="integration_test",
            password="test_pass",
            power_name="ITALY",
            model_name="test_model",
            game_id=None,
        )

        # Mock the complex dependencies
        bot.client = MagicMock(spec=WebSocketDiplomacyClient)
        bot.agent = MagicMock(spec=DiplomacyAgent)
        bot.game_history = GameHistory()

        return bot

    def test_message_tracking_during_phase_transition(self, bot_integration_setup):
        """Test that message statistics are logged during phase transitions."""
        bot = bot_integration_setup

        # Set up some message data
        bot.message_counts = {"FRANCE": 2, "GERMANY": 1}
        bot.current_phase = "S1901M"

        # Mock the client powers for log_message_statistics
        bot.client.powers = {
            "ITALY": MagicMock(is_eliminated=lambda: False),
            "FRANCE": MagicMock(is_eliminated=lambda: False),
            "GERMANY": MagicMock(is_eliminated=lambda: False),
        }

        # Mock the log_message_statistics method to track if it's called
        with patch.object(bot, "log_message_statistics") as mock_log:
            # Simulate phase transition logic
            new_phase = "F1901M"
            if new_phase != bot.current_phase:
                bot.current_phase = new_phase
                bot.game_history.add_phase(new_phase)
                bot.orders_submitted = False
                bot.current_negotiation_round = 0
                bot.negotiation_complete = False

                # This is the key logic from _handle_phase_update_async
                if hasattr(bot, "message_counts") and bot.message_counts:
                    bot.log_message_statistics()

            # Verify that statistics were logged
            mock_log.assert_called_once()

    def test_priority_contact_influence_on_response_decisions(self, bot_integration_setup):
        """Test that priority contacts influence message response decisions."""
        bot = bot_integration_setup

        # Set up priority contacts
        bot.priority_contacts = ["FRANCE", "GERMANY"]

        # Greetings trigger a response regardless of who sends them
        priority_message = Message(sender="FRANCE", recipient="ITALY", message="Hello Italy!", phase="S1901M")
        non_priority_message = Message(sender="AUSTRIA", recipient="ITALY", message="Hello Italy!", phase="S1901M")

        # "Noted." matches no question, greeting, keyword or length rule, so
        # the sender's priority status is the only thing that can decide.
        neutral_priority_message = Message(sender="FRANCE", recipient="ITALY", message="Noted.", phase="S1901M")
        neutral_non_priority_message = Message(sender="AUSTRIA", recipient="ITALY", message="Noted.", phase="S1901M")

        strategic_keywords = [
            "alliance",
            "deal",
            "propose",
            "agreement",
            "support",
            "attack",
            "coordinate",
            "move",
            "order",
            "help",
            "work together",
            "partner",
            "enemy",
            "threat",
            "negotiate",
            "discuss",
            "plan",
            "strategy",
            "bounce",
            "convoy",
            "retreat",
        ]

        # Test the response decision logic (from _consider_message_response)
        def should_respond_to_message(message):
            message_lower = message.message.lower()
            return any(
                (
                    "?" in message.message,  # Questions
                    any(word in message_lower for word in ["hello", "hi", "greetings"]),  # Greetings
                    any(keyword in message_lower for keyword in strategic_keywords),  # Strategic content
                    len(message.message.split()) > 15,  # Longer messages
                    message.sender in bot.priority_contacts,  # Priority contacts
                )
            )

        # Both greetings get a response because of the "hello" keyword
        assert should_respond_to_message(priority_message) is True
        assert should_respond_to_message(non_priority_message) is True

        # With neutral content only the priority contact gets a response
        assert should_respond_to_message(neutral_priority_message) is True
        assert should_respond_to_message(neutral_non_priority_message) is False

        assert priority_message.sender in bot.priority_contacts
        assert non_priority_message.sender not in bot.priority_contacts

    def test_message_persistence_across_multiple_phases(self, bot_integration_setup):
        """Test that message counts persist across multiple game phases."""
        bot = bot_integration_setup

        # Simulate message accumulation over multiple phases
        phases = ["S1901M", "F1901M", "W1901A", "S1902M"]

        for phase in phases:
            bot.current_phase = phase

            # Simulate receiving messages in each phase
            if phase.endswith("M"):  # Movement phases
                # More messages during movement phases
                for sender in ["FRANCE", "GERMANY"]:
                    bot.message_counts[sender] = bot.message_counts.get(sender, 0) + 2
            else:
                # Fewer messages during other phases
                bot.message_counts["FRANCE"] = bot.message_counts.get("FRANCE", 0) + 1

            bot._update_priority_contacts()

        # The counts are deterministic, so pin the exact totals
        assert bot.message_counts["FRANCE"] == 7  # 2+2+2 from movement phases + 1 from adjustment
        assert bot.message_counts["GERMANY"] == 6  # 2+2+2 from movement phases only

        # France should be top priority due to higher count
        assert len(bot.priority_contacts) > 0
        assert bot.priority_contacts[0] == "FRANCE"
|
||||
Loading…
Add table
Add a link
Reference in a new issue