mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-05-02 17:46:00 +00:00
WIP: Continued work on the bot client. The bots successfully play the game, but don't negotiate yet.
Signed-off-by: Tyler Marques <me@tylermarques.com>
This commit is contained in:
parent
c09d5bfe7e
commit
79a31d615d
13 changed files with 2483 additions and 263 deletions
|
|
@ -1 +1 @@
|
|||
# Test package initialization
|
||||
# Test package initialization
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ import pytest_asyncio
|
|||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
from .test_fake_websocket_server import FakeServerManager, FakeWebSocketServer
|
||||
from typed_websocket_client import TypedWebSocketDiplomacyClient
|
||||
from websocket_diplomacy_client import WebSocketDiplomacyClient
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function", autouse=True)
|
||||
|
|
@ -39,13 +39,13 @@ def credentials() -> Generator[Dict[str, str], None]:
|
|||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client() -> AsyncGenerator[TypedWebSocketDiplomacyClient, None]:
|
||||
async def client() -> AsyncGenerator[WebSocketDiplomacyClient, None]:
|
||||
"""
|
||||
Fixture that provides a TypedWebSocketDiplomacyClient instance.
|
||||
Fixture that provides a WebSocketDiplomacyClient instance.
|
||||
|
||||
The client is configured to connect to the fake server on port 8433.
|
||||
"""
|
||||
client = TypedWebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
|
||||
client = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
|
||||
yield client
|
||||
|
||||
# Cleanup: close the client connection
|
||||
|
|
@ -58,13 +58,13 @@ async def client() -> AsyncGenerator[TypedWebSocketDiplomacyClient, None]:
|
|||
@pytest.fixture
|
||||
async def authenticated_client(fake_server, client):
|
||||
"""
|
||||
Fixture that provides an authenticated TypedWebSocketDiplomacyClient.
|
||||
Fixture that provides an authenticated WebSocketDiplomacyClient.
|
||||
|
||||
This client is already connected and authenticated, ready for testing
|
||||
game operations.
|
||||
"""
|
||||
await client.connect()
|
||||
token = await client.authenticate("test_user", "test_password")
|
||||
await client.connect_and_authenticate("test_user", "test_password")
|
||||
token = client.token
|
||||
assert token is not None
|
||||
assert client.token == token
|
||||
yield client
|
||||
|
|
|
|||
|
|
@ -38,49 +38,41 @@ from models import (
|
|||
class FakeWebSocketServer:
|
||||
"""
|
||||
A fake WebSocket server that implements the Diplomacy protocol for testing.
|
||||
|
||||
|
||||
This server maintains minimal state and responds to requests with valid
|
||||
responses according to the WEBSOCKET.md protocol specification.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, host: str = "localhost", port: int = 8433):
    """Create an unstarted fake server with seeded mock state.

    Args:
        host: Host name passed to ``websockets.serve`` by ``start()``.
        port: TCP port passed to ``websockets.serve`` by ``start()``.
    """
    self.host = host
    self.port = port
    # Populated by start(); stays None until the server has been started.
    self.server = None
    self.connected_clients: Set[websockets.WebSocketServerProtocol] = set()

    # Mock server state
    self.users = {"test_user": "test_password", "ai_player": "password", "player1": "password"}  # username -> password
    self.tokens: Dict[str, str] = {}  # token -> username
    self.games: Dict[str, Dict[str, Any]] = {}  # game_id -> game data
    self.game_counter = 1
|
||||
|
||||
|
||||
async def start(self):
    """Start the fake WebSocket server.

    Stores the object returned by ``websockets.serve`` on ``self.server`` so
    that ``stop()`` can close it later.
    """
    logger.info(f"Starting fake WebSocket server on {self.host}:{self.port}")
    self.server = await websockets.serve(self.handle_client, self.host, self.port)
    logger.info("Fake WebSocket server started")
|
||||
|
||||
|
||||
async def stop(self):
    """Stop the fake WebSocket server.

    Does nothing when the server was never started (``self.server`` is falsy).
    """
    if self.server:
        self.server.close()
        await self.server.wait_closed()
        logger.info("Fake WebSocket server stopped")
|
||||
|
||||
|
||||
async def handle_client(self, websocket):
|
||||
"""Handle a new WebSocket client connection."""
|
||||
self.connected_clients.add(websocket)
|
||||
logger.info(f"Client connected from {websocket.remote_address}")
|
||||
|
||||
|
||||
try:
|
||||
async for message in websocket:
|
||||
await self.handle_message(websocket, message)
|
||||
|
|
@ -90,49 +82,45 @@ class FakeWebSocketServer:
|
|||
logger.error(f"Error handling client: {e}")
|
||||
finally:
|
||||
self.connected_clients.discard(websocket)
|
||||
|
||||
|
||||
async def handle_message(self, websocket, message_text: str):
    """Handle an incoming message from a client.

    Decodes ``message_text`` as JSON, parses it into a request, asks
    ``generate_response`` for a reply and sends the serialized reply back
    on ``websocket``. Any exception along the way is logged and answered
    with an ``ErrorResponse`` of type ``PARSING_ERROR``.

    Args:
        websocket: Connection the message arrived on; replies are sent on it.
        message_text: Raw message text received from the client.
    """
    try:
        # Parse the raw message
        data = json.loads(message_text)
        request = parse_message(data)

        # Generate response based on request type
        response = await self.generate_response(request)

        # Send response back to client
        if response:
            response_data = serialize_message(response)
            await websocket.send(json.dumps(response_data))
            logger.debug(f"Sent response: {response.name}")

    except Exception as e:
        logger.error(f"Error processing message: {e}")
        logger.debug(f"Raw message: {message_text}")

        # Send error response if we can extract request_id; a fresh UUID is
        # used when the payload has no "request_id" key.
        try:
            data = json.loads(message_text)
            request_id = data.get("request_id", str(uuid.uuid4()))
            error_response = ErrorResponse(request_id=request_id, error_type="PARSING_ERROR", message=str(e))
            error_data = serialize_message(error_response)
            await websocket.send(json.dumps(error_data))
        except Exception:
            # Deliberately not a bare ``except``: task cancellation and
            # KeyboardInterrupt must still propagate.
            logger.error("Could not send error response")
|
||||
|
||||
|
||||
async def generate_response(self, request) -> Optional[Any]:
|
||||
"""Generate appropriate response for a request."""
|
||||
logger.info(f"Processing request: {request.name}")
|
||||
|
||||
|
||||
# Authentication requests
|
||||
if isinstance(request, SignInRequest):
|
||||
return await self.handle_sign_in(request)
|
||||
|
||||
|
||||
# Channel-level requests (require token)
|
||||
elif isinstance(request, CreateGameRequest):
|
||||
return await self.handle_create_game(request)
|
||||
|
|
@ -140,7 +128,7 @@ class FakeWebSocketServer:
|
|||
return await self.handle_join_game(request)
|
||||
elif isinstance(request, ListGamesRequest):
|
||||
return await self.handle_list_games(request)
|
||||
|
||||
|
||||
# Game-level requests (require token + game context)
|
||||
elif isinstance(request, SetOrdersRequest):
|
||||
return await self.handle_set_orders(request)
|
||||
|
|
@ -148,45 +136,36 @@ class FakeWebSocketServer:
|
|||
return await self.handle_process_game(request)
|
||||
elif isinstance(request, GetAllPossibleOrdersRequest):
|
||||
return await self.handle_get_possible_orders(request)
|
||||
|
||||
|
||||
else:
|
||||
logger.warning(f"Unhandled request type: {type(request)}")
|
||||
return ErrorResponse(
|
||||
request_id=request.request_id,
|
||||
error_type="UNSUPPORTED_REQUEST",
|
||||
message=f"Request type {request.name} not supported by fake server"
|
||||
request_id=request.request_id, error_type="UNSUPPORTED_REQUEST", message=f"Request type {request.name} not supported by fake server"
|
||||
)
|
||||
|
||||
|
||||
async def handle_sign_in(self, request: SignInRequest) -> Any:
    """Handle authentication request.

    Returns a ``DataTokenResponse`` carrying a newly generated token when
    the username/password pair matches ``self.users``; otherwise returns an
    ``ErrorResponse`` of type ``AUTHENTICATION_ERROR``. Issued tokens are
    recorded in ``self.tokens`` (token -> username).
    """
    username = request.username
    password = request.password

    if username in self.users and self.users[username] == password:
        # Generate auth token
        token = f"fake_token_{uuid.uuid4().hex[:16]}"
        self.tokens[token] = username

        return DataTokenResponse(request_id=request.request_id, data=token)
    else:
        return ErrorResponse(request_id=request.request_id, error_type="AUTHENTICATION_ERROR", message="Invalid username or password")
|
||||
|
||||
async def handle_create_game(self, request: CreateGameRequest) -> Any:
|
||||
"""Handle game creation request."""
|
||||
if not self.validate_token(request.token):
|
||||
return self.create_auth_error(request.request_id)
|
||||
|
||||
|
||||
# Create a new game
|
||||
game_id = f"GAME_{self.game_counter:04d}"
|
||||
self.game_counter += 1
|
||||
|
||||
|
||||
game_data = {
|
||||
"game_id": game_id,
|
||||
"map_name": request.map_name,
|
||||
|
|
@ -203,102 +182,83 @@ class FakeWebSocketServer:
|
|||
"RUSSIA": {"units": [], "centers": ["MOS", "SEV", "STP", "WAR"], "is_eliminated": False},
|
||||
"TURKEY": {"units": [], "centers": ["ANK", "CON", "SMY"], "is_eliminated": False},
|
||||
},
|
||||
"controlled_powers": {}
|
||||
"controlled_powers": {},
|
||||
}
|
||||
|
||||
|
||||
# If a specific power was requested, assign it
|
||||
if request.power_name:
|
||||
game_data["controlled_powers"][request.power_name] = self.tokens[request.token]
|
||||
|
||||
|
||||
self.games[game_id] = game_data
|
||||
|
||||
return DataGameResponse(
|
||||
request_id=request.request_id,
|
||||
data=game_data
|
||||
)
|
||||
|
||||
|
||||
return DataGameResponse(request_id=request.request_id, data=game_data)
|
||||
|
||||
async def handle_join_game(self, request: JoinGameRequest) -> Any:
    """Handle game join request.

    Returns an authentication error for an unknown token, a
    ``GAME_NOT_FOUND`` error for an unknown game id, and otherwise a
    ``DataGameResponse`` with the game data. When ``request.power_name`` is
    set, that power is recorded as controlled by the token's username.
    """
    if not self.validate_token(request.token):
        return self.create_auth_error(request.request_id)

    game_id = request.game_id
    if game_id not in self.games:
        return ErrorResponse(request_id=request.request_id, error_type="GAME_NOT_FOUND", message=f"Game {game_id} not found")

    # NOTE: dict.copy() is shallow, so the nested "controlled_powers" dict is
    # shared with the stored game and the assignment below updates both.
    game_data = self.games[game_id].copy()

    # If a specific power was requested, assign it
    if request.power_name:
        game_data["controlled_powers"][request.power_name] = self.tokens[request.token]
        self.games[game_id] = game_data

    return DataGameResponse(request_id=request.request_id, data=game_data)
|
||||
|
||||
async def handle_list_games(self, request: ListGamesRequest) -> Any:
    """Handle list games request.

    Returns an authentication error for an unknown token; otherwise a
    ``DataGamesResponse`` whose data is one summary dict per stored game.
    """
    if not self.validate_token(request.token):
        return self.create_auth_error(request.request_id)

    # Return simplified game info
    games_list = [
        {
            "game_id": game_id,
            "map_name": game_data["map_name"],
            "status": game_data["status"],
            "phase": game_data["phase"],
            "n_controls": game_data["n_controls"],
        }
        for game_id, game_data in self.games.items()
    ]

    return DataGamesResponse(request_id=request.request_id, data=games_list)
|
||||
|
||||
async def handle_set_orders(self, request: SetOrdersRequest) -> Any:
    """Handle set orders request.

    Returns an authentication error for an unknown token, a
    ``GAME_NOT_FOUND`` error for an unknown game id, and otherwise an
    ``OkResponse``. The orders are only logged, never stored or validated.
    """
    if not self.validate_token(request.token):
        return self.create_auth_error(request.request_id)

    game_id = request.game_id
    if game_id not in self.games:
        return ErrorResponse(request_id=request.request_id, error_type="GAME_NOT_FOUND", message=f"Game {game_id} not found")

    # In a real server, we'd validate the orders and store them
    # For testing, we just acknowledge receipt
    logger.info(f"Orders received for {request.game_role}: {request.orders}")

    return OkResponse(request_id=request.request_id)
|
||||
|
||||
|
||||
async def handle_process_game(self, request: ProcessGameRequest) -> Any:
|
||||
"""Handle process game request."""
|
||||
if not self.validate_token(request.token):
|
||||
return self.create_auth_error(request.request_id)
|
||||
|
||||
|
||||
game_id = request.game_id
|
||||
if game_id not in self.games:
|
||||
return ErrorResponse(
|
||||
request_id=request.request_id,
|
||||
error_type="GAME_NOT_FOUND",
|
||||
message=f"Game {game_id} not found"
|
||||
)
|
||||
|
||||
return ErrorResponse(request_id=request.request_id, error_type="GAME_NOT_FOUND", message=f"Game {game_id} not found")
|
||||
|
||||
# Simulate game processing
|
||||
game_data = self.games[game_id]
|
||||
current_phase = game_data["phase"]
|
||||
|
||||
|
||||
# Simple phase progression
|
||||
if current_phase == "S1901M":
|
||||
game_data["phase"] = "F1901M"
|
||||
|
|
@ -307,55 +267,48 @@ class FakeWebSocketServer:
|
|||
else:
|
||||
# For testing, cycle back to start
|
||||
game_data["phase"] = "S1902M"
|
||||
|
||||
|
||||
self.games[game_id] = game_data
|
||||
|
||||
|
||||
# Send notification to all clients (in real implementation)
|
||||
# For testing, we'll just return OK
|
||||
return OkResponse(request_id=request.request_id)
|
||||
|
||||
|
||||
async def handle_get_possible_orders(self, request: GetAllPossibleOrdersRequest) -> Any:
    """Handle get possible orders request.

    Returns an authentication error for an unknown token; otherwise a
    ``DataPossibleOrdersResponse`` with a fixed, hard-coded mapping of
    location -> candidate orders (the same for every game and phase).
    """
    if not self.validate_token(request.token):
        return self.create_auth_error(request.request_id)

    # Return mock possible orders
    possible_orders = {
        "PAR": ["A PAR H", "A PAR - BUR", "A PAR - PIC", "A PAR - GAS"],
        "BRE": ["F BRE H", "F BRE - MAO", "F BRE - ENG", "F BRE - PIC"],
        "MAR": ["A MAR H", "A MAR - GAS", "A MAR - SPA", "A MAR - PIE"],
    }

    return DataPossibleOrdersResponse(request_id=request.request_id, data=possible_orders)
|
||||
|
||||
def validate_token(self, token: str) -> bool:
    """Validate an authentication token.

    Returns True when ``token`` is a key of ``self.tokens``, i.e. it was
    issued by a successful sign-in on this server instance.
    """
    return token in self.tokens
|
||||
|
||||
|
||||
def create_auth_error(self, request_id: str) -> ErrorResponse:
    """Create a standard authentication error response.

    Args:
        request_id: Id of the request being rejected; echoed in the response.

    Returns:
        An ``ErrorResponse`` with error type ``AUTHENTICATION_ERROR``.
    """
    return ErrorResponse(request_id=request_id, error_type="AUTHENTICATION_ERROR", message="Invalid or missing authentication token")
|
||||
|
||||
|
||||
class FakeServerManager:
    """Context manager for the fake WebSocket server.

    Intended for ``async with FakeServerManager() as server:``; the server
    is started on entry and stopped on exit.
    """

    def __init__(self, host: str = "localhost", port: int = 8433):
        """Create (but do not start) the wrapped ``FakeWebSocketServer``."""
        self.server = FakeWebSocketServer(host, port)

    async def __aenter__(self):
        """Start the server and return the ``FakeWebSocketServer`` instance."""
        await self.server.start()
        # Give the server a moment to start
        await asyncio.sleep(0.1)
        return self.server

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Stop the server; returns None, so exceptions from the body propagate."""
        await self.server.stop()
|
||||
|
||||
|
|
@ -375,4 +328,4 @@ async def run_fake_server(host: str = "localhost", port: int = 8433):
|
|||
|
||||
if __name__ == "__main__":
    # Run the fake server for manual testing
    asyncio.run(run_fake_server())
|
||||
|
|
|
|||
347
bot_client/tests/test_inter_power_messaging.py
Normal file
347
bot_client/tests/test_inter_power_messaging.py
Normal file
|
|
@ -0,0 +1,347 @@
|
|||
"""
|
||||
Integration tests for inter-power messaging functionality.
|
||||
|
||||
These tests verify that the enhanced WebSocket client can properly:
|
||||
1. Send messages between powers
|
||||
2. Retrieve and filter messages
|
||||
3. Handle broadcast messaging
|
||||
4. Track message patterns for strategic communication
|
||||
|
||||
The tests use the fake server infrastructure but test real message flow
|
||||
without mocking the transport layer.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
from typing import Dict, List
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from websocket_diplomacy_client import WebSocketDiplomacyClient, connect_to_diplomacy_server
|
||||
from diplomacy.engine.message import Message, GLOBAL
|
||||
|
||||
|
||||
class TestBasicMessaging:
    """Test core messaging functionality between powers."""

    @pytest.fixture
    async def two_clients(self, fake_server):
        """Fixture providing two authenticated clients in the same game.

        Yields a dict with keys ``"france"`` and ``"germany"``. The France
        client creates the game and the Germany client joins it. Both
        clients are closed on teardown.
        """
        # Create two clients
        client_france = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
        client_germany = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)

        try:
            # Connect and authenticate both clients
            await client_france.connect_and_authenticate("test_user", "test_password")
            await client_germany.connect_and_authenticate("ai_player", "password")

            # Create a game with first client as FRANCE
            await client_france.create_game(
                map_name="standard",
                rules=["IGNORE_ERRORS", "POWER_CHOICE"],
                power_name="FRANCE",
                n_controls=2,  # Only need 2 powers for testing
            )

            # Second client joins as GERMANY
            await client_germany.join_game(game_id=client_france.game_id, power_name="GERMANY")

            # Synchronize both clients
            await client_france.synchronize()
            await client_germany.synchronize()

            yield {"france": client_france, "germany": client_germany}

        finally:
            # Best-effort cleanup. Each client is closed on its own so that a
            # failure closing one does not leave the other open, and only
            # Exception is suppressed so cancellation still propagates.
            for open_client in (client_france, client_germany):
                try:
                    await open_client.close()
                except Exception:
                    pass

    async def test_direct_message_sending(self, two_clients):
        """Test sending a direct message between two powers."""
        france_client = two_clients["france"]
        germany_client = two_clients["germany"]

        # Send message from FRANCE to GERMANY
        test_message = "Hello Germany, shall we form an alliance?"
        await france_client.send_message(sender="FRANCE", recipient="GERMANY", message=test_message)

        # Allow message to propagate
        await asyncio.sleep(0.1)

        # Synchronize both clients to get latest messages
        await france_client.synchronize()
        await germany_client.synchronize()

        # Check that Germany received the message
        germany_messages = await germany_client.get_recent_messages(limit=10)

        # Find our message (first FRANCE -> GERMANY message, if any)
        sent_message = next(
            (msg for msg in germany_messages if msg.sender == "FRANCE" and msg.recipient == "GERMANY"),
            None,
        )

        assert sent_message is not None, "Message was not received by Germany"
        assert sent_message.message == test_message
        assert sent_message.sender == "FRANCE"
        assert sent_message.recipient == "GERMANY"

    async def test_global_message_broadcasting(self, two_clients):
        """Test broadcasting a global message."""
        france_client = two_clients["france"]
        germany_client = two_clients["germany"]

        # Send global message from FRANCE
        test_message = "Greetings to all powers! Let's have a good game."
        await france_client.send_message(sender="FRANCE", recipient=GLOBAL, message=test_message)

        # Allow message to propagate
        await asyncio.sleep(0.1)

        # Synchronize both clients
        await france_client.synchronize()
        await germany_client.synchronize()

        # Check that both clients can see the global message
        france_messages = await france_client.get_recent_messages(limit=10)
        germany_messages = await germany_client.get_recent_messages(limit=10)

        # Find the global message in both clients
        def find_global_message(messages):
            """Return the first FRANCE -> GLOBAL message, or None."""
            return next(
                (msg for msg in messages if msg.sender == "FRANCE" and msg.recipient == GLOBAL),
                None,
            )

        france_msg = find_global_message(france_messages)
        germany_msg = find_global_message(germany_messages)

        assert france_msg is not None, "France should see its own global message"
        assert germany_msg is not None, "Germany should see France's global message"
        assert france_msg.message == test_message
        assert germany_msg.message == test_message

    async def test_broadcast_to_all_active_powers(self, two_clients):
        """Test the enhanced broadcast functionality."""
        france_client = two_clients["france"]
        germany_client = two_clients["germany"]

        # Send broadcast message to all active powers
        test_message = "This is a broadcast to all active powers."
        await france_client.send_broadcast_message(sender="FRANCE", message=test_message)

        # Allow messages to propagate
        await asyncio.sleep(0.1)

        # Synchronize clients
        await france_client.synchronize()
        await germany_client.synchronize()

        # Check that Germany received the direct message from broadcast
        germany_messages = await germany_client.get_recent_messages(limit=10)

        broadcast_message = next(
            (
                msg
                for msg in germany_messages
                if msg.sender == "FRANCE" and msg.recipient == "GERMANY" and msg.message == test_message
            ),
            None,
        )

        assert broadcast_message is not None, "Germany should receive broadcast message"
        assert broadcast_message.message == test_message

    async def test_message_filtering_by_phase(self, two_clients):
        """Test message retrieval filtered by game phase."""
        france_client = two_clients["france"]

        # Get current phase
        current_phase = france_client.get_current_short_phase()

        # Send a message
        test_message = "Phase-specific message"
        await france_client.send_message(sender="FRANCE", recipient="GERMANY", message=test_message)

        # Allow message to propagate
        await asyncio.sleep(0.1)
        await france_client.synchronize()

        # Get messages for current phase
        phase_messages = await france_client.get_recent_messages(phase=current_phase, limit=10)

        # Find our message
        found_message = next((msg for msg in phase_messages if msg.message == test_message), None)

        assert found_message is not None, "Message should be found in current phase"
        assert found_message.phase == current_phase
|
||||
|
||||
|
||||
class TestMessageHistory:
    """Test message history and retrieval functionality."""

    @pytest.fixture
    async def client_with_messages(self, fake_server):
        """Fixture providing a client with some test messages.

        The client creates a game as FRANCE and sends four messages (two
        global, two direct) before being yielded. It is closed on teardown.
        """
        client = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)

        try:
            await client.connect_and_authenticate("test_user", "test_password")
            await client.create_game(
                map_name="standard",
                rules=["IGNORE_ERRORS", "POWER_CHOICE"],
                power_name="FRANCE",
                n_controls=1,
            )
            await client.synchronize()

            # Send several test messages
            test_messages = [
                ("Hello world!", GLOBAL),
                ("Private message to England", "ENGLAND"),
                ("Another global message", GLOBAL),
                ("Direct to Germany", "GERMANY"),
            ]

            for message, recipient in test_messages:
                await client.send_message(sender="FRANCE", recipient=recipient, message=message)
                await asyncio.sleep(0.05)  # Small delay between messages

            await client.synchronize()
            yield client

        finally:
            # Best-effort cleanup; only Exception is suppressed so that task
            # cancellation still propagates.
            try:
                await client.close()
            except Exception:
                pass

    async def test_get_recent_messages_limit(self, client_with_messages):
        """Test that message limit parameter works correctly."""
        client = client_with_messages

        # Get recent messages with different limits
        messages_3 = await client.get_recent_messages(limit=3)
        messages_2 = await client.get_recent_messages(limit=2)
        messages_1 = await client.get_recent_messages(limit=1)

        assert len(messages_3) <= 3
        assert len(messages_2) <= 2
        assert len(messages_1) <= 1

        # Messages should be in reverse chronological order (most recent first).
        # Compare each message with its successor; a missing time_sent counts as 0.
        for newer, older in zip(messages_3, messages_3[1:]):
            newer_time = newer.time_sent or 0
            older_time = older.time_sent or 0
            assert newer_time >= older_time, "Messages should be in reverse chronological order"

    async def test_message_retrieval_by_sender(self, client_with_messages):
        """Test filtering messages by sender."""
        client = client_with_messages

        # Get all recent messages
        all_messages = await client.get_recent_messages(limit=20)

        # Filter messages from FRANCE
        france_messages = [msg for msg in all_messages if msg.sender == "FRANCE"]

        # We should have the 4 messages we sent as FRANCE
        assert len(france_messages) >= 4, f"Expected at least 4 FRANCE messages, got {len(france_messages)}"

        # All messages should be from FRANCE
        for msg in france_messages:
            assert msg.sender == "FRANCE"

    async def test_empty_message_history(self, fake_server):
        """Test behavior when no messages exist."""
        client = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)

        try:
            await client.connect_and_authenticate("ai_player", "password")
            await client.create_game(
                map_name="standard",
                rules=["IGNORE_ERRORS"],
                power_name="AUSTRIA",
                n_controls=1,
            )
            await client.synchronize()

            # Get messages when none exist
            messages = await client.get_recent_messages(limit=10)

            # Should return empty list, not None or error
            assert isinstance(messages, list)
            assert len(messages) == 0

        finally:
            # Best-effort cleanup; see client_with_messages for rationale.
            try:
                await client.close()
            except Exception:
                pass
|
||||
|
||||
|
||||
class TestMessageErrorHandling:
    """Test error handling in messaging functionality."""

    @pytest.fixture
    async def client(self, fake_server):
        """Basic authenticated client for error testing.

        The client creates a game as ITALY before being yielded and is
        closed on teardown.
        """
        client = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)

        try:
            await client.connect_and_authenticate("player1", "password")
            await client.create_game(
                map_name="standard",
                rules=["IGNORE_ERRORS"],
                power_name="ITALY",
                n_controls=1,
            )
            await client.synchronize()
            yield client

        finally:
            # Best-effort cleanup; only Exception is suppressed so that task
            # cancellation still propagates.
            try:
                await client.close()
            except Exception:
                pass

    async def test_send_message_to_invalid_recipient(self, client):
        """Test sending message to non-existent power."""
        # This should not raise an exception - the server/game should handle invalid recipients
        try:
            await client.send_message(sender="ITALY", recipient="INVALID_POWER", message="This should not crash")
            # If we get here, the call succeeded (which is fine)
        except Exception as e:
            # If an exception is raised, its text must point at the bad
            # recipient rather than at some unrelated failure.
            assert "INVALID_POWER" in str(e) or "recipient" in str(e).lower()

    async def test_send_empty_message(self, client):
        """Test sending empty message."""
        # Empty messages should be handled gracefully
        await client.send_message(sender="ITALY", recipient="FRANCE", message="")
        # If we get here without exception, the empty message was handled properly

    async def test_get_messages_before_game_setup(self, fake_server):
        """Test getting messages when no game is joined."""
        client = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)

        try:
            await client.connect_and_authenticate("test_user", "test_password")

            # Try to get messages without joining a game
            with pytest.raises(Exception):  # Should raise some form of exception
                await client.get_recent_messages()

        finally:
            # Best-effort cleanup; see the client fixture for rationale.
            try:
                await client.close()
            except Exception:
                pass
|
||||
356
bot_client/tests/test_message_functionality_unit.py
Normal file
356
bot_client/tests/test_message_functionality_unit.py
Normal file
|
|
@ -0,0 +1,356 @@
|
|||
"""
|
||||
Unit tests for the new inter-power messaging functionality.
|
||||
|
||||
These tests focus on testing the specific functionality we added:
|
||||
1. Enhanced WebSocketDiplomacyClient methods
|
||||
2. Message tracking and statistics in SingleBotPlayer
|
||||
3. Negotiation targeting logic
|
||||
|
||||
These are more focused unit tests that don't require full server integration.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
from typing import Dict, List
|
||||
|
||||
from websocket_diplomacy_client import WebSocketDiplomacyClient
|
||||
from single_bot_player import SingleBotPlayer
|
||||
from websocket_negotiations import (
|
||||
analyze_recent_messages_for_targeting,
|
||||
should_participate_in_negotiations,
|
||||
get_negotiation_delay,
|
||||
)
|
||||
from diplomacy.engine.message import Message, GLOBAL
|
||||
|
||||
|
||||
class TestWebSocketClientEnhancements:
|
||||
"""Test the enhanced methods we added to WebSocketDiplomacyClient."""
|
||||
|
||||
@pytest.fixture
def mock_client(self):
    """Create a mocked WebSocketDiplomacyClient for testing.

    The client's ``game`` attribute is replaced with a ``MagicMock`` that
    exposes four powers (ITALY reports itself eliminated) and three
    messages keyed by their list index.
    """
    client = WebSocketDiplomacyClient()

    # Mock the game and related objects
    client.game = MagicMock()
    mock_powers = {
        "FRANCE": MagicMock(is_eliminated=lambda: False),
        "ENGLAND": MagicMock(is_eliminated=lambda: False),
        "GERMANY": MagicMock(is_eliminated=lambda: False),
        "ITALY": MagicMock(is_eliminated=lambda: True),  # Eliminated
    }
    client.game.powers = mock_powers

    # Mock message history
    mock_messages = [
        Message(sender="ENGLAND", recipient="FRANCE", message="Hello France!", phase="S1901M"),
        Message(sender="GERMANY", recipient="FRANCE", message="Greetings!", phase="S1901M"),
        Message(sender="FRANCE", recipient="GLOBAL", message="Hello everyone!", phase="S1901M"),
    ]
    client.game.messages = dict(enumerate(mock_messages))

    return client
|
||||
|
||||
async def test_send_broadcast_message(self, mock_client):
|
||||
"""Test the new send_broadcast_message method."""
|
||||
# Mock the send_message method and powers property
|
||||
mock_client.send_message = AsyncMock()
|
||||
|
||||
# Mock the powers property to return our mock powers
|
||||
with patch.object(type(mock_client), "powers", new_callable=lambda: property(lambda self: self.game.powers)):
|
||||
# Send broadcast message
|
||||
await mock_client.send_broadcast_message(sender="FRANCE", message="Hello to all active powers!")
|
||||
|
||||
# Should have called send_message for each active power (excluding sender and eliminated)
|
||||
expected_calls = 2 # ENGLAND and GERMANY (not ITALY because eliminated, not FRANCE because sender)
|
||||
assert mock_client.send_message.call_count == expected_calls
|
||||
|
||||
# Check that messages were sent to the right powers
|
||||
call_args_list = mock_client.send_message.call_args_list
|
||||
recipients = [call[1]["recipient"] for call in call_args_list] # Get recipient from kwargs
|
||||
|
||||
assert "ENGLAND" in recipients
|
||||
assert "GERMANY" in recipients
|
||||
assert "FRANCE" not in recipients # Shouldn't send to self
|
||||
assert "ITALY" not in recipients # Shouldn't send to eliminated power
|
||||
|
||||
async def test_get_recent_messages_filtering(self, mock_client):
|
||||
"""Test the new get_recent_messages method with filtering."""
|
||||
# Mock current phase
|
||||
mock_client.get_current_short_phase = MagicMock(return_value="S1901M")
|
||||
|
||||
# Mock the messages property to return our mock messages
|
||||
with patch.object(type(mock_client), "messages", new_callable=lambda: property(lambda self: self.messages)):
|
||||
# Test getting recent messages
|
||||
recent_messages = await mock_client.get_recent_messages(limit=5)
|
||||
|
||||
# Should return the mocked messages
|
||||
assert len(recent_messages) <= 5
|
||||
assert len(recent_messages) == 3 # We have 3 mock messages
|
||||
|
||||
# Test phase filtering
|
||||
phase_messages = await mock_client.get_recent_messages(phase="S1901M", limit=10)
|
||||
assert all(msg.phase == "S1901M" for msg in phase_messages)
|
||||
|
||||
async def test_get_recent_messages_limit(self, mock_client):
|
||||
"""Test that the limit parameter works correctly."""
|
||||
# Mock the messages property to return our mock messages
|
||||
with patch.object(type(mock_client), "messages", new_callable=lambda: property(lambda self: self.messages)):
|
||||
# Test with limit smaller than available messages
|
||||
limited_messages = await mock_client.get_recent_messages(limit=2)
|
||||
assert len(limited_messages) <= 2
|
||||
|
||||
# Test with limit larger than available messages
|
||||
all_messages = await mock_client.get_recent_messages(limit=100)
|
||||
assert len(all_messages) == 3 # Should not exceed available messages
|
||||
|
||||
|
||||
class TestSingleBotPlayerMessageTracking:
    """Unit tests for the message bookkeeping kept by SingleBotPlayer."""

    @pytest.fixture
    def mock_bot_player(self):
        """Build a FRANCE bot whose client, agent and game history are mocks."""
        bot = SingleBotPlayer(
            username="test_bot",
            password="test_pass",
            power_name="FRANCE",
            model_name="test_model",
        )

        # Swap every collaborator for a mock so nothing touches the network.
        bot.client = MagicMock()
        bot.agent = MagicMock()
        bot.game_history = MagicMock()

        # Four powers, none of them eliminated.
        bot.client.powers = {
            power: MagicMock(is_eliminated=lambda: False)
            for power in ("FRANCE", "ENGLAND", "GERMANY", "ITALY")
        }

        return bot

    def test_message_counting_initialization(self, mock_bot_player):
        """A freshly built bot has empty tracking containers of the right type."""
        tracked = (
            ("message_counts", dict),
            ("response_counts", dict),
            ("priority_contacts", list),
        )

        for attr_name, expected_type in tracked:
            assert isinstance(getattr(mock_bot_player, attr_name), expected_type)

        for attr_name, _ in tracked:
            assert len(getattr(mock_bot_player, attr_name)) == 0

    def test_priority_contact_updates(self, mock_bot_player):
        """Priority contacts are the four most active senders, busiest first."""
        player = mock_bot_player
        player.message_counts = {
            "ENGLAND": 5,
            "GERMANY": 3,
            "ITALY": 2,
            "AUSTRIA": 1,
            "RUSSIA": 4,
        }

        player._update_priority_contacts()

        # Expected ranking by message count: 5, 4, 3, 2 (AUSTRIA's 1 is cut).
        expected_ranking = ["ENGLAND", "RUSSIA", "GERMANY", "ITALY"]
        assert len(player.priority_contacts) == 4
        for position, power in enumerate(expected_ranking):
            assert player.priority_contacts[position] == power

    def test_message_statistics_generation(self, mock_bot_player):
        """get_message_statistics reports totals and per-power response rates."""
        player = mock_bot_player
        player.current_phase = "S1901M"
        player.message_counts = {"ENGLAND": 4, "GERMANY": 2}
        player.response_counts = {"ENGLAND": 3, "GERMANY": 1}
        player.priority_contacts = ["ENGLAND", "GERMANY"]

        report = player.get_message_statistics()

        # Identity and totals
        assert report["power_name"] == "FRANCE"
        assert report["current_phase"] == "S1901M"
        assert report["total_messages_received"] == 6  # 4 + 2
        assert report["total_responses_sent"] == 4  # 3 + 1

        # Per-power response rates
        rates = report["response_rate_by_power"]
        assert rates["ENGLAND"] == 0.75  # 3/4
        assert rates["GERMANY"] == 0.5  # 1/2

    def test_response_decision_logic(self, mock_bot_player):
        """Check a local copy of the response heuristic against two greetings."""
        player = mock_bot_player
        player.priority_contacts = ["ENGLAND"]

        priority_message = MagicMock(sender="ENGLAND", recipient="FRANCE", message="Hello France!")
        non_priority_message = MagicMock(sender="GERMANY", recipient="FRANCE", message="Hello France!")

        # Mirror of the decision rules extracted from _consider_message_response.
        def should_respond(message):
            text = message.message
            lowered = text.lower()
            strategic_keywords = ["alliance", "deal", "propose", "agreement"]

            is_question = "?" in text
            is_greeting = any(word in lowered for word in ["hello", "hi", "greetings"])
            is_strategic = any(keyword in lowered for keyword in strategic_keywords)
            is_long = len(text.split()) > 15
            from_priority_contact = message.sender in player.priority_contacts

            return is_question or is_greeting or is_strategic or is_long or from_priority_contact

        # Both messages qualify through the "hello" greeting rule.
        assert should_respond(priority_message) is True
        assert should_respond(non_priority_message) is True

        # Only ENGLAND is a priority contact.
        assert priority_message.sender in player.priority_contacts
        assert non_priority_message.sender not in player.priority_contacts
|
||||
|
||||
|
||||
class TestNegotiationTargeting:
    """Tests for the helpers that decide whom and when to negotiate with."""

    @pytest.fixture
    def mock_client_with_messages(self):
        """Return a MagicMock client exposing four recent messages and four powers."""
        stub = MagicMock()

        # (sender, recipient, text) for each recent message, all in S1901M.
        history = (
            ("ENGLAND", "FRANCE", "Direct to France"),
            ("ENGLAND", "FRANCE", "Another to France"),
            ("GERMANY", "FRANCE", "Message to France"),
            ("ITALY", "GLOBAL", "Global message"),
        )
        recent = [
            Message(sender=sender, recipient=recipient, message=text, phase="S1901M")
            for sender, recipient, text in history
        ]
        stub.get_recent_messages = AsyncMock(return_value=recent)

        stub.powers = {
            power: MagicMock(is_eliminated=lambda: False)
            for power in ("FRANCE", "ENGLAND", "GERMANY", "ITALY")
        }

        return stub

    async def test_analyze_recent_messages_for_targeting(self, mock_client_with_messages):
        """Targets exclude the caller and rank ENGLAND ahead of GERMANY."""
        targets = await analyze_recent_messages_for_targeting(mock_client_with_messages, "FRANCE", max_messages=20)

        assert isinstance(targets, list)
        assert "FRANCE" not in targets  # Should not include self

        # ENGLAND sent two direct messages to FRANCE, GERMANY only one.
        if len(targets) >= 2:
            assert targets[0] == "ENGLAND"
            if "GERMANY" in targets:
                assert targets.index("ENGLAND") < targets.index("GERMANY")

    def test_negotiation_delay_calculation(self):
        """Delays shrink from the first round to the last and stay positive."""
        first_delay, middle_delay, final_delay = [
            get_negotiation_delay(round_number=round_number, total_rounds=3)
            for round_number in (1, 2, 3)
        ]

        assert first_delay > middle_delay  # First round waits longest
        assert final_delay < middle_delay  # Final round waits least
        assert all(delay > 0 for delay in [first_delay, middle_delay, final_delay])

    async def test_should_participate_in_negotiations(self):
        """Participation requires a live power in a movement phase."""
        power = MagicMock()
        power.is_eliminated.return_value = False

        client = MagicMock()
        client.get_power.return_value = power
        client.get_current_short_phase.return_value = "S1901M"

        agent = MagicMock()
        agent.power_name = "FRANCE"

        # Case 1: active power, movement phase, orders available.
        with patch("websocket_negotiations.gather_possible_orders") as mock_orders:
            mock_orders.return_value = ["A Paris - Hold"]  # Has orders

            outcome = await should_participate_in_negotiations(client, agent)
            assert outcome is True

        # Case 2: eliminated power.
        power.is_eliminated.return_value = True
        outcome = await should_participate_in_negotiations(client, agent)
        assert outcome is False

        # Case 3: alive again, but in a retreat phase.
        power.is_eliminated.return_value = False
        client.get_current_short_phase.return_value = "S1901R"

        with patch("websocket_negotiations.gather_possible_orders") as mock_orders:
            mock_orders.return_value = ["A Paris - Hold"]

            outcome = await should_participate_in_negotiations(client, agent)
            assert outcome is False
|
||||
|
||||
|
||||
class TestIntegrationScenarios:
    """Scenario-style tests that run a bot with all collaborators mocked."""

    def test_message_persistence_across_phases(self):
        """Message counts keep growing as the simulated game changes phase."""
        player = SingleBotPlayer(
            username="test_bot",
            password="test_pass",
            power_name="AUSTRIA",
            model_name="test_model",
        )

        # Replace collaborators with mocks.
        player.client = MagicMock()
        player.agent = MagicMock()
        player.game_history = MagicMock()

        for phase in ("S1901M", "F1901M", "W1901A", "S1902M"):
            player.current_phase = phase

            # Movement phases: two messages each from FRANCE and ENGLAND.
            # Any other phase: a single message from FRANCE.
            if phase.endswith("M"):
                senders, increment = ("FRANCE", "ENGLAND"), 2
            else:
                senders, increment = ("FRANCE",), 1

            for sender in senders:
                player.message_counts[sender] = player.message_counts.get(sender, 0) + increment

            player._update_priority_contacts()

        # Counts must have accumulated instead of resetting per phase.
        assert player.message_counts["FRANCE"] >= 5
        assert player.message_counts["ENGLAND"] >= 4
        assert len(player.priority_contacts) > 0
        assert player.priority_contacts[0] == "FRANCE"  # Should be top priority
|
||||
403
bot_client/tests/test_message_persistence.py
Normal file
403
bot_client/tests/test_message_persistence.py
Normal file
|
|
@ -0,0 +1,403 @@
|
|||
"""
|
||||
Integration tests for message persistence and tracking functionality.
|
||||
|
||||
These tests verify that the SingleBotPlayer can properly:
|
||||
1. Track message counts and response patterns
|
||||
2. Update priority contacts based on messaging activity
|
||||
3. Generate accurate message statistics
|
||||
4. Log statistics during phase transitions
|
||||
5. Maintain message persistence across game phases
|
||||
|
||||
Tests focus on the tracking and statistics features rather than
|
||||
the actual AI message generation.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from typing import Dict
|
||||
|
||||
from single_bot_player import SingleBotPlayer
|
||||
from websocket_diplomacy_client import WebSocketDiplomacyClient
|
||||
from diplomacy.engine.message import Message
|
||||
from ai_diplomacy.agent import DiplomacyAgent
|
||||
from ai_diplomacy.game_history import GameHistory
|
||||
|
||||
|
||||
class TestMessageCounting:
    """Tests for per-power message and response counters on SingleBotPlayer."""

    @pytest.fixture
    def mock_bot_player(self):
        """Return an ENGLAND bot with a spec'd mock client/agent and a real GameHistory."""
        player = SingleBotPlayer(
            username="test_bot",
            password="test_pass",
            power_name="ENGLAND",
            model_name="test_model",
            game_id="test_game",
        )

        player.client = MagicMock(spec=WebSocketDiplomacyClient)
        player.agent = MagicMock(spec=DiplomacyAgent)
        player.game_history = GameHistory()

        return player

    def test_initial_message_tracking_state(self, mock_bot_player):
        """All tracking containers exist, have the right type and start empty."""
        player = mock_bot_player

        assert isinstance(player.message_counts, dict)
        assert isinstance(player.response_counts, dict)
        assert isinstance(player.priority_contacts, list)

        for container in (player.message_counts, player.response_counts, player.priority_contacts):
            assert len(container) == 0

    def test_message_count_tracking(self, mock_bot_player):
        """Only direct messages from other powers to this bot are counted."""
        player = mock_bot_player

        incoming = [
            Message(sender="FRANCE", recipient="ENGLAND", message="Hello England!", phase="S1901M"),
            Message(sender="FRANCE", recipient="ENGLAND", message="Another message", phase="S1901M"),
            Message(sender="GERMANY", recipient="ENGLAND", message="Greetings!", phase="S1901M"),
            Message(sender="ITALY", recipient="GLOBAL", message="Global message", phase="S1901M"),  # Should not be counted
            Message(sender="ENGLAND", recipient="FRANCE", message="Self message", phase="S1901M"),  # Should not be counted
        ]

        # Replay the bookkeeping that _on_message_received performs.
        for msg in incoming:
            player.game_history.add_message(
                phase_name=msg.phase,
                sender=msg.sender,
                recipient=msg.recipient,
                message_content=msg.message,
            )

            # Skip anything that is not addressed to this bot by another power.
            addressed_to_bot = msg.recipient == player.power_name
            if not (addressed_to_bot and msg.sender != player.power_name):
                continue

            player.message_counts[msg.sender] = player.message_counts.get(msg.sender, 0) + 1
            player._update_priority_contacts()

        assert player.message_counts["FRANCE"] == 2
        assert player.message_counts["GERMANY"] == 1
        assert "ITALY" not in player.message_counts  # Global message not counted
        assert "ENGLAND" not in player.message_counts  # Self message not counted

    def test_response_count_tracking(self, mock_bot_player):
        """Each outgoing response increments the counter for its recipient."""
        player = mock_bot_player

        outgoing = [
            ("FRANCE", "Thanks for your message!"),
            ("FRANCE", "Another response to France"),
            ("GERMANY", "Hello Germany"),
        ]

        # Only the recipient matters for the counter; the text is ignored.
        for recipient, _text in outgoing:
            player.response_counts[recipient] = player.response_counts.get(recipient, 0) + 1

        assert player.response_counts["FRANCE"] == 2
        assert player.response_counts["GERMANY"] == 1

    def test_priority_contact_updates(self, mock_bot_player):
        """The four busiest senders become priority contacts, busiest first."""
        player = mock_bot_player
        player.message_counts = {
            "FRANCE": 5,  # Most active
            "GERMANY": 3,  # Third most active
            "ITALY": 2,  # Fourth most active
            "AUSTRIA": 1,  # Least active
            "RUSSIA": 4,  # Second most active
        }

        player._update_priority_contacts()

        # AUSTRIA (1 message) falls outside the top four.
        expected_ranking = ("FRANCE", "RUSSIA", "GERMANY", "ITALY")
        assert len(player.priority_contacts) == 4
        for position, power in enumerate(expected_ranking):
            assert player.priority_contacts[position] == power

    def test_priority_contacts_with_fewer_powers(self, mock_bot_player):
        """With only two senders, both become priority contacts in count order."""
        player = mock_bot_player
        player.message_counts = {"FRANCE": 3, "GERMANY": 1}

        player._update_priority_contacts()

        assert len(player.priority_contacts) == 2
        assert player.priority_contacts[0] == "FRANCE"
        assert player.priority_contacts[1] == "GERMANY"
|
||||
|
||||
|
||||
class TestMessageStatistics:
    """Test message statistics generation and reporting."""

    @pytest.fixture
    def mock_bot_player(self):
        """Create a SingleBotPlayer (ENGLAND) with mocked dependencies.

        Declared on this class because pytest only exposes a fixture defined
        inside a test class to that class's own tests; the identically named
        fixture on TestMessageCounting is not visible here.
        """
        bot = SingleBotPlayer(username="test_bot", password="test_pass", power_name="ENGLAND", model_name="test_model", game_id="test_game")

        # Mock the client and agent; keep a real (empty) game history.
        bot.client = MagicMock(spec=WebSocketDiplomacyClient)
        bot.agent = MagicMock(spec=DiplomacyAgent)
        bot.game_history = GameHistory()

        return bot

    @pytest.fixture
    def bot_with_message_data(self, mock_bot_player):
        """Bot player with pre-populated message tracking data.

        Seeds message/response counters for four powers and installs seven
        mocked powers on the client, with AUSTRIA eliminated.
        """
        bot = mock_bot_player

        # Set up realistic message and response data
        bot.message_counts = {
            "FRANCE": 8,
            "GERMANY": 5,
            "ITALY": 3,
            "RUSSIA": 2,
        }

        bot.response_counts = {
            "FRANCE": 6,  # 75% response rate
            "GERMANY": 2,  # 40% response rate
            "ITALY": 3,  # 100% response rate
            "RUSSIA": 0,  # 0% response rate
        }

        bot.priority_contacts = ["FRANCE", "GERMANY", "ITALY", "RUSSIA"]
        bot.current_phase = "S1901M"

        # Mock the client's powers for active power detection
        bot.client.powers = {
            "ENGLAND": MagicMock(is_eliminated=lambda: False),
            "FRANCE": MagicMock(is_eliminated=lambda: False),
            "GERMANY": MagicMock(is_eliminated=lambda: False),
            "ITALY": MagicMock(is_eliminated=lambda: False),
            "RUSSIA": MagicMock(is_eliminated=lambda: False),
            "AUSTRIA": MagicMock(is_eliminated=lambda: True),  # Eliminated
            "TURKEY": MagicMock(is_eliminated=lambda: False),
        }

        return bot

    def test_message_statistics_generation(self, bot_with_message_data):
        """Test that message statistics are generated correctly."""
        stats = bot_with_message_data.get_message_statistics()

        # Check basic structure
        assert stats["power_name"] == "ENGLAND"
        assert stats["current_phase"] == "S1901M"
        assert isinstance(stats["message_counts_by_power"], dict)
        assert isinstance(stats["response_counts_by_power"], dict)
        assert isinstance(stats["response_rate_by_power"], dict)
        assert isinstance(stats["priority_contacts"], list)
        assert isinstance(stats["active_powers"], list)

        # Check calculated values
        assert stats["total_messages_received"] == 18  # 8+5+3+2
        assert stats["total_responses_sent"] == 11  # 6+2+3+0

        # Check response rates
        assert stats["response_rate_by_power"]["FRANCE"] == 0.75  # 6/8
        assert stats["response_rate_by_power"]["GERMANY"] == 0.4  # 2/5
        assert stats["response_rate_by_power"]["ITALY"] == 1.0  # 3/3
        assert stats["response_rate_by_power"]["RUSSIA"] == 0.0  # 0/2

        # Check active powers (should exclude eliminated AUSTRIA and self)
        active_powers = stats["active_powers"]
        assert "AUSTRIA" not in active_powers  # Eliminated
        assert "ENGLAND" not in active_powers  # Self
        assert "FRANCE" in active_powers
        assert "TURKEY" in active_powers

    def test_message_statistics_empty_data(self, mock_bot_player):
        """Test statistics generation with no message data."""
        bot = mock_bot_player
        bot.current_phase = "S1901M"

        # Two live powers on the board, but no messages exchanged yet.
        bot.client.powers = {
            "ENGLAND": MagicMock(is_eliminated=lambda: False),
            "FRANCE": MagicMock(is_eliminated=lambda: False),
        }

        stats = bot.get_message_statistics()

        assert stats["total_messages_received"] == 0
        assert stats["total_responses_sent"] == 0
        assert len(stats["message_counts_by_power"]) == 0
        assert len(stats["response_counts_by_power"]) == 0
        assert len(stats["response_rate_by_power"]) == 0
        assert len(stats["priority_contacts"]) == 0

    def test_message_statistics_logging(self, bot_with_message_data, caplog):
        """Test that message statistics are logged correctly."""
        bot = bot_with_message_data

        # Capture INFO records too; by default only WARNING and above reach
        # caplog. NOTE(review): assumes log_message_statistics logs through the
        # stdlib ``logging`` module - confirm against the implementation.
        caplog.set_level("INFO")

        # Call the logging method
        bot.log_message_statistics()

        # Check that appropriate log messages were generated
        log_output = caplog.text
        assert "Message Statistics for ENGLAND" in log_output
        assert "Total messages received: 18" in log_output
        assert "Total responses sent: 11" in log_output
        assert "Priority contacts:" in log_output

        # Check that individual power stats are logged
        assert "FRANCE:" in log_output
        assert "GERMANY:" in log_output
        assert "75%" in log_output  # France response rate
        assert "40%" in log_output  # Germany response rate
|
||||
|
||||
|
||||
class TestMessagePersistenceIntegration:
    """Test integration of message persistence with game flow."""

    @pytest.fixture
    def bot_integration_setup(self, fake_server):
        """Return an ITALY bot with mocked client/agent and a real GameHistory.

        The fixture is synchronous on purpose: it awaits nothing, and every
        test that consumes it is a plain (non-async) test function.

        Note: real SingleBotPlayer initialization depends on AI client setup,
        so for now these tests focus on the core tracking logic only.
        """
        bot = SingleBotPlayer(username="integration_test", password="test_pass", power_name="ITALY", model_name="test_model", game_id=None)

        # Mock the complex dependencies
        bot.client = MagicMock(spec=WebSocketDiplomacyClient)
        bot.agent = MagicMock(spec=DiplomacyAgent)
        bot.game_history = GameHistory()

        return bot

    def test_message_tracking_during_phase_transition(self, bot_integration_setup):
        """Test that message statistics are logged during phase transitions."""
        bot = bot_integration_setup

        # Set up some message data
        bot.message_counts = {"FRANCE": 2, "GERMANY": 1}
        bot.current_phase = "S1901M"

        # Mock the client powers for log_message_statistics
        bot.client.powers = {
            "ITALY": MagicMock(is_eliminated=lambda: False),
            "FRANCE": MagicMock(is_eliminated=lambda: False),
            "GERMANY": MagicMock(is_eliminated=lambda: False),
        }

        # Mock the log_message_statistics method to track if it's called
        with patch.object(bot, "log_message_statistics") as mock_log:
            # Simulate phase transition logic
            new_phase = "F1901M"
            if new_phase != bot.current_phase:
                bot.current_phase = new_phase
                bot.game_history.add_phase(new_phase)
                bot.orders_submitted = False
                bot.current_negotiation_round = 0
                bot.negotiation_complete = False

                # This is the key logic from _handle_phase_update_async
                if hasattr(bot, "message_counts") and bot.message_counts:
                    bot.log_message_statistics()

            # Verify that statistics were logged
            mock_log.assert_called_once()

    def test_priority_contact_influence_on_response_decisions(self, bot_integration_setup):
        """Test that priority contacts influence message response decisions."""
        bot = bot_integration_setup

        # Set up priority contacts
        bot.priority_contacts = ["FRANCE", "GERMANY"]

        # Greetings from a priority and a non-priority sender.
        priority_message = Message(sender="FRANCE", recipient="ITALY", message="Hello Italy!", phase="S1901M")
        non_priority_message = Message(sender="AUSTRIA", recipient="ITALY", message="Hello Italy!", phase="S1901M")

        # Neutral text that matches none of the content rules, so the sender's
        # priority status is the only thing that can trigger a response.
        neutral_from_priority = Message(sender="FRANCE", recipient="ITALY", message="Noted.", phase="S1901M")
        neutral_from_other = Message(sender="AUSTRIA", recipient="ITALY", message="Noted.", phase="S1901M")

        # Local copy of the response decision logic (from _consider_message_response)
        def should_respond_to_message(message):
            message_lower = message.message.lower()
            strategic_keywords = [
                "alliance",
                "deal",
                "propose",
                "agreement",
                "support",
                "attack",
                "coordinate",
                "move",
                "order",
                "help",
                "work together",
                "partner",
                "enemy",
                "threat",
                "negotiate",
                "discuss",
                "plan",
                "strategy",
                "bounce",
                "convoy",
                "retreat",
            ]

            return any(
                [
                    "?" in message.message,  # Questions
                    any(word in message_lower for word in ["hello", "hi", "greetings"]),  # Greetings
                    any(keyword in message_lower for keyword in strategic_keywords),  # Strategic content
                    len(message.message.split()) > 15,  # Longer messages
                    message.sender in bot.priority_contacts,  # Priority contacts
                ]
            )

        # Both greetings trigger a response through the "hello" rule.
        assert should_respond_to_message(priority_message) is True
        assert should_respond_to_message(non_priority_message) is True

        # With neutral content, only the priority contact gets a response.
        assert should_respond_to_message(neutral_from_priority) is True
        assert should_respond_to_message(neutral_from_other) is False

        assert priority_message.sender in bot.priority_contacts
        assert non_priority_message.sender not in bot.priority_contacts

    def test_message_persistence_across_multiple_phases(self, bot_integration_setup):
        """Test that message counts persist across multiple game phases."""
        bot = bot_integration_setup

        # Simulate message accumulation over multiple phases
        phases = ["S1901M", "F1901M", "W1901A", "S1902M"]

        for phase in phases:
            bot.current_phase = phase

            # Simulate receiving messages in each phase
            if phase.endswith("M"):  # Movement phases
                # More messages during movement phases
                for sender in ["FRANCE", "GERMANY"]:
                    bot.message_counts[sender] = bot.message_counts.get(sender, 0) + 2
            else:
                # Fewer messages during other phases
                bot.message_counts["FRANCE"] = bot.message_counts.get("FRANCE", 0) + 1

            bot._update_priority_contacts()

        # Exact totals: a lower bound would also pass if counts were double-counted.
        assert bot.message_counts["FRANCE"] == 7  # 2+2+2 from three movement phases + 1 from adjustment
        assert bot.message_counts["GERMANY"] == 6  # 2+2+2 from movement phases only

        # France should be top priority due to higher count
        assert len(bot.priority_contacts) > 0
        assert bot.priority_contacts[0] == "FRANCE"
|
||||
440
bot_client/tests/test_multi_bot_scenarios.py
Normal file
440
bot_client/tests/test_multi_bot_scenarios.py
Normal file
|
|
@ -0,0 +1,440 @@
|
|||
"""
|
||||
Integration tests for multi-bot messaging scenarios.
|
||||
|
||||
These tests verify realistic inter-power communication scenarios with
|
||||
multiple bots interacting simultaneously. Tests focus on:
|
||||
1. Multi-bot conversation flows
|
||||
2. Negotiation coordination between multiple powers
|
||||
3. Message response patterns in realistic scenarios
|
||||
4. Phase transition coordination across multiple bots
|
||||
|
||||
Tests use the fake server but simulate realistic multi-power interactions
|
||||
without mocking the core message transport logic.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from typing import Dict, List
|
||||
|
||||
from websocket_diplomacy_client import WebSocketDiplomacyClient
|
||||
from single_bot_player import SingleBotPlayer
|
||||
from ai_diplomacy.agent import DiplomacyAgent
|
||||
from ai_diplomacy.game_history import GameHistory
|
||||
|
||||
|
||||
class TestTwoBotConversation:
|
||||
"""Test conversation flow between two bots."""
|
||||
|
||||
@pytest.fixture
async def two_bot_setup(self, fake_server):
    """Setup two bots in the same game for conversation testing.

    Connects a FRANCE and an ENGLAND client to the fake server on
    localhost:8433; FRANCE creates a two-player game and ENGLAND joins it.

    Yields:
        dict with keys ``france_client``, ``england_client`` and ``game_id``.
    """
    # Create two WebSocket clients
    client_france = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
    client_england = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)

    try:
        # Connect both clients
        await client_france.connect_and_authenticate("france_bot", "password")
        await client_england.connect_and_authenticate("england_bot", "password")

        # Create game with France
        await client_france.create_game(
            map_name="standard",
            rules=["IGNORE_ERRORS", "POWER_CHOICE"],
            power_name="FRANCE",
            n_controls=2,
        )

        # England joins the game
        await client_england.join_game(game_id=client_france.game_id, power_name="ENGLAND")

        # Synchronize both clients
        await client_france.synchronize()
        await client_england.synchronize()

        yield {"france_client": client_france, "england_client": client_england, "game_id": client_france.game_id}

    finally:
        # Best-effort teardown. Each client is closed on its own so a failure
        # on one cannot leave the other open, and only Exception is swallowed
        # so task cancellation and KeyboardInterrupt still propagate.
        for client in (client_france, client_england):
            try:
                await client.close()
            except Exception:
                pass
|
||||
|
||||
async def test_basic_two_bot_exchange(self, two_bot_setup):
|
||||
"""Test basic message exchange between two bots."""
|
||||
france_client = two_bot_setup["france_client"]
|
||||
england_client = two_bot_setup["england_client"]
|
||||
|
||||
# France initiates conversation
|
||||
await france_client.send_message(sender="FRANCE", recipient="ENGLAND", message="Hello England! Shall we discuss our border?")
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
await england_client.synchronize()
|
||||
|
||||
# Check England received the message
|
||||
england_messages = await england_client.get_recent_messages(limit=5)
|
||||
france_message = None
|
||||
for msg in england_messages:
|
||||
if msg.sender == "FRANCE" and msg.recipient == "ENGLAND":
|
||||
france_message = msg
|
||||
break
|
||||
|
||||
assert france_message is not None
|
||||
assert "border" in france_message.message
|
||||
|
||||
# England responds
|
||||
await england_client.send_message(sender="ENGLAND", recipient="FRANCE", message="Indeed, France. I propose we coordinate our fleets.")
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
await france_client.synchronize()
|
||||
|
||||
# Check France received the response
|
||||
france_messages = await france_client.get_recent_messages(limit=5)
|
||||
england_response = None
|
||||
for msg in france_messages:
|
||||
if msg.sender == "ENGLAND" and msg.recipient == "FRANCE":
|
||||
england_response = msg
|
||||
break
|
||||
|
||||
assert england_response is not None
|
||||
assert "coordinate" in england_response.message
|
||||
|
||||
async def test_conversation_thread_tracking(self, two_bot_setup):
|
||||
"""Test that conversation threads can be tracked across multiple exchanges."""
|
||||
france_client = two_bot_setup["france_client"]
|
||||
england_client = two_bot_setup["england_client"]
|
||||
|
||||
# Simulate a conversation thread
|
||||
conversation = [
|
||||
("FRANCE", "ENGLAND", "England, are you interested in an alliance?"),
|
||||
("ENGLAND", "FRANCE", "Yes France, what do you propose?"),
|
||||
("FRANCE", "ENGLAND", "Let's coordinate attacks on Germany."),
|
||||
("ENGLAND", "FRANCE", "Agreed. I'll move my fleet to support you."),
|
||||
]
|
||||
|
||||
for sender, recipient, message in conversation:
|
||||
if sender == "FRANCE":
|
||||
await france_client.send_message(sender, recipient, message)
|
||||
else:
|
||||
await england_client.send_message(sender, recipient, message)
|
||||
|
||||
await asyncio.sleep(0.05) # Small delay between messages
|
||||
|
||||
# Synchronize both clients
|
||||
await france_client.synchronize()
|
||||
await england_client.synchronize()
|
||||
|
||||
# Get conversation history from both perspectives
|
||||
france_messages = await france_client.get_recent_messages(limit=10)
|
||||
england_messages = await england_client.get_recent_messages(limit=10)
|
||||
|
||||
# Filter for conversation between France and England
|
||||
def filter_conversation(messages):
|
||||
return [msg for msg in messages if (msg.sender in ["FRANCE", "ENGLAND"] and msg.recipient in ["FRANCE", "ENGLAND"])]
|
||||
|
||||
france_conv = filter_conversation(france_messages)
|
||||
england_conv = filter_conversation(england_messages)
|
||||
|
||||
# Both should see the same conversation
|
||||
assert len(france_conv) >= 4
|
||||
assert len(england_conv) >= 4
|
||||
|
||||
# Check that key terms from the conversation appear
|
||||
all_messages_text = " ".join([msg.message for msg in france_conv])
|
||||
assert "alliance" in all_messages_text
|
||||
assert "Germany" in all_messages_text
|
||||
assert "support" in all_messages_text
|
||||
|
||||
|
||||
class TestThreeBotNegotiation:
    """Test more complex negotiations with three bots."""

    @pytest.fixture
    async def three_bot_setup(self, fake_server):
        """Connect FRANCE, ENGLAND and GERMANY clients to one shared game.

        Yields:
            dict with ``clients`` (power name -> client) and ``game_id``.

        Every client that was created is closed on teardown.
        """
        clients = {}
        powers = ["FRANCE", "ENGLAND", "GERMANY"]

        try:
            # Create and connect three clients
            for power in powers:
                client = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
                # Register before connecting so that teardown also closes a
                # client whose connection attempt failed part-way.
                clients[power] = client
                await client.connect_and_authenticate(f"{power.lower()}_bot", "password")

            # Create game with France
            await clients["FRANCE"].create_game(
                map_name="standard",
                rules=["IGNORE_ERRORS", "POWER_CHOICE"],
                power_name="FRANCE",
                n_controls=3,
            )

            game_id = clients["FRANCE"].game_id

            # Other powers join
            await clients["ENGLAND"].join_game(game_id=game_id, power_name="ENGLAND")
            await clients["GERMANY"].join_game(game_id=game_id, power_name="GERMANY")

            # Synchronize all clients
            for client in clients.values():
                await client.synchronize()

            yield {"clients": clients, "game_id": game_id}

        finally:
            # Best-effort cleanup; only ``Exception`` is suppressed so that
            # cancellation / Ctrl-C still propagate.
            for client in clients.values():
                try:
                    await client.close()
                except Exception:
                    pass

    async def test_three_way_alliance_negotiation(self, three_bot_setup):
        """Test alliance negotiation between three powers."""
        clients = three_bot_setup["clients"]

        # Simulate alliance negotiation sequence
        negotiations = [
            ("FRANCE", "ENGLAND", "England, shall we form an alliance against Germany?"),
            ("ENGLAND", "FRANCE", "I'm interested. What are your terms?"),
            ("FRANCE", "GERMANY", "Germany, France and England are discussing cooperation."),
            ("GERMANY", "FRANCE", "I see. Perhaps we should talk as well."),
            ("GERMANY", "ENGLAND", "England, what is France offering you?"),
            ("ENGLAND", "GERMANY", "Germany, I think we should all work together."),
        ]

        # Send all negotiation messages
        for sender, recipient, message in negotiations:
            await clients[sender].send_message(sender, recipient, message)
            await asyncio.sleep(0.1)  # Allow message processing

        # Synchronize all clients
        for client in clients.values():
            await client.synchronize()

        # Analyze message patterns from each perspective
        for power, client in clients.items():
            messages = await client.get_recent_messages(limit=20)

            # Count messages involving this power
            involving_power = [msg for msg in messages if msg.sender == power or msg.recipient == power]

            # Each power should be involved in multiple messages
            assert len(involving_power) >= 2, f"{power} should be involved in multiple messages"

        # Check that all three powers have communicated (from France's view)
        all_messages = await clients["FRANCE"].get_recent_messages(limit=20)
        senders = {msg.sender for msg in all_messages}

        # All three powers should appear as senders
        assert "FRANCE" in senders
        assert "ENGLAND" in senders
        assert "GERMANY" in senders

    async def test_broadcast_with_private_follow_ups(self, three_bot_setup):
        """Test broadcast message followed by private conversations."""
        clients = three_bot_setup["clients"]

        # France sends a global announcement
        await clients["FRANCE"].send_broadcast_message(sender="FRANCE", message="All powers: I propose we establish clear spheres of influence.")

        await asyncio.sleep(0.2)

        # Follow up with private messages to each power
        private_messages = [
            ("FRANCE", "ENGLAND", "England, I suggest you focus on the seas."),
            ("FRANCE", "GERMANY", "Germany, the eastern approach might suit you."),
        ]

        for sender, recipient, message in private_messages:
            await clients[sender].send_message(sender, recipient, message)
            await asyncio.sleep(0.1)

        # Synchronize all clients
        for client in clients.values():
            await client.synchronize()

        # Check that England and Germany received their specific messages
        england_messages = await clients["ENGLAND"].get_recent_messages(limit=10)
        germany_messages = await clients["GERMANY"].get_recent_messages(limit=10)

        # NOTE(review): only the private follow-ups are asserted below; receipt
        # of the broadcast itself is not checked. Confirm how the fake server
        # routes broadcast messages before adding that assertion.
        england_private = [msg for msg in england_messages if msg.sender == "FRANCE" and msg.recipient == "ENGLAND"]

        assert len(england_private) >= 1, "England should receive private message"
        assert any("seas" in msg.message for msg in england_private), "England should get seas message"

        germany_private = [msg for msg in germany_messages if msg.sender == "FRANCE" and msg.recipient == "GERMANY"]

        assert len(germany_private) >= 1, "Germany should receive private message"
        assert any("eastern" in msg.message for msg in germany_private), "Germany should get eastern message"
|
||||
|
||||
|
||||
class TestBotPlayerMessageIntegration:
    """Test SingleBotPlayer message handling in multi-bot scenarios."""

    @pytest.fixture
    def mock_bot_players(self):
        """Build one SingleBotPlayer per power with its client and agent mocked out.

        Returns:
            dict mapping "FRANCE", "ENGLAND" and "GERMANY" to their bot.
        """

        def build_bot(power_name):
            player = SingleBotPlayer(
                username=f"{power_name.lower()}_bot", password="test_pass", power_name=power_name, model_name="test_model", game_id="test_game"
            )

            # Replace network/AI collaborators with mocks; keep a real history
            player.client = MagicMock(spec=WebSocketDiplomacyClient)
            player.agent = MagicMock(spec=DiplomacyAgent)
            player.game_history = GameHistory()

            # The mocked agent must still report the right power
            player.agent.power_name = power_name
            return player

        return {power_name: build_bot(power_name) for power_name in ["FRANCE", "ENGLAND", "GERMANY"]}

    def test_multi_bot_priority_contact_evolution(self, mock_bot_players):
        """Test how priority contacts evolve in multi-bot scenarios."""
        bots = mock_bot_players

        # Asymmetric communication: France hears from everyone, England mostly
        # from France, Germany receives only a few messages.
        incoming_counts = {
            "FRANCE": {
                "ENGLAND": 5,  # England is very active with France
                "GERMANY": 2,  # Germany occasionally talks to France
                "ITALY": 1,  # Italy sends one message
            },
            "ENGLAND": {
                "FRANCE": 8,  # France talks to England a lot
                "GERMANY": 1,  # Germany sends one message
            },
            "GERMANY": {
                "FRANCE": 3,  # Some communication with France
                "ENGLAND": 1,  # Minimal with England
            },
        }

        for receiver, counts in incoming_counts.items():
            bots[receiver].message_counts = counts
            bots[receiver]._update_priority_contacts()

        france_bot, england_bot, germany_bot = bots["FRANCE"], bots["ENGLAND"], bots["GERMANY"]

        # The top priority contact is whoever sent the most messages
        assert france_bot.priority_contacts[0] == "ENGLAND"
        assert england_bot.priority_contacts[0] == "FRANCE"
        assert germany_bot.priority_contacts[0] == "FRANCE"

        # One priority entry per power that made contact
        assert len(france_bot.priority_contacts) == 3
        assert len(england_bot.priority_contacts) == 2
        assert len(germany_bot.priority_contacts) == 2

    def test_response_pattern_analysis(self, mock_bot_players):
        """Test analysis of response patterns across multiple bots."""
        bots = mock_bot_players

        # France: responsive to everyone (100% response rate)
        bots["FRANCE"].message_counts = {"ENGLAND": 4, "GERMANY": 3}
        bots["FRANCE"].response_counts = {"ENGLAND": 4, "GERMANY": 3}

        # England: selective responder (only responds to France)
        bots["ENGLAND"].message_counts = {"FRANCE": 6, "GERMANY": 2}
        bots["ENGLAND"].response_counts = {"FRANCE": 6, "GERMANY": 0}

        # Germany: poor responder (low response rates)
        bots["GERMANY"].message_counts = {"FRANCE": 5, "ENGLAND": 3}
        bots["GERMANY"].response_counts = {"FRANCE": 2, "ENGLAND": 1}

        # Give every mocked client a set of non-eliminated powers for statistics
        for bot in bots.values():
            bot.client.powers = {name: MagicMock(is_eliminated=lambda: False) for name in ["FRANCE", "ENGLAND", "GERMANY", "ITALY"]}

        # Generate statistics for each bot
        stats = {}
        for power_name, bot in bots.items():
            stats[power_name] = bot.get_message_statistics()

        def average_rate(power_stats):
            rates = power_stats["response_rate_by_power"]
            return sum(rates.values()) / len(rates)

        # France should have a perfect overall response rate
        assert average_rate(stats["FRANCE"]) == 1.0

        # England should have a selective response pattern
        england_rates = stats["ENGLAND"]["response_rate_by_power"]
        assert england_rates["FRANCE"] == 1.0  # Always responds to France
        assert england_rates["GERMANY"] == 0.0  # Never responds to Germany

        # Germany should have a low overall response rate
        assert average_rate(stats["GERMANY"]) < 0.5

    def test_message_history_consistency_across_bots(self, mock_bot_players):
        """Test that message history tracking is consistent across different bots."""
        bots = mock_bot_players

        # Simulate the same set of messages being processed by different bots
        # (as they would see them in a real game)
        from diplomacy.engine.message import Message

        shared_messages = [
            Message(sender="FRANCE", recipient="ENGLAND", message="Alliance proposal", phase="S1901M"),
            Message(sender="ENGLAND", recipient="FRANCE", message="I accept", phase="S1901M"),
            Message(sender="GERMANY", recipient="FRANCE", message="What about me?", phase="S1901M"),
            Message(sender="FRANCE", recipient="GERMANY", message="You're welcome too", phase="S1901M"),
        ]

        for msg in shared_messages:
            for power_name, bot in bots.items():
                # Every bot records every message in its game history
                bot.game_history.add_message(
                    phase_name=msg.phase,
                    sender=msg.sender,
                    recipient=msg.recipient,
                    message_content=msg.message,
                )

                # Only messages addressed to this bot by someone else are counted
                addressed_to_bot = msg.recipient == power_name and msg.sender != power_name
                if not addressed_to_bot:
                    continue
                bot.message_counts[msg.sender] = bot.message_counts.get(msg.sender, 0) + 1
                bot._update_priority_contacts()

        # France received one message each from England and Germany
        assert bots["FRANCE"].message_counts["ENGLAND"] == 1
        assert bots["FRANCE"].message_counts["GERMANY"] == 1

        # England received one message (from France)
        assert bots["ENGLAND"].message_counts["FRANCE"] == 1

        # Germany received one message (from France)
        assert bots["GERMANY"].message_counts["FRANCE"] == 1

        # Game history is consistent: all bots see all 4 messages
        for power, bot in bots.items():
            all_messages = bot.game_history.get_messages_for_phase("S1901M")
            assert len(all_messages) == 4, f"{power} should see all 4 messages in game history"
|
||||
433
bot_client/tests/test_negotiation_rounds.py
Normal file
433
bot_client/tests/test_negotiation_rounds.py
Normal file
|
|
@ -0,0 +1,433 @@
|
|||
"""
|
||||
Integration tests for negotiation round functionality.
|
||||
|
||||
These tests verify that the negotiation system can properly:
|
||||
1. Conduct strategic negotiation rounds with message targeting
|
||||
2. Analyze recent messages for targeting decisions
|
||||
3. Handle negotiation timing and coordination
|
||||
4. Determine negotiation participation appropriately
|
||||
|
||||
Tests use mocked AI responses for predictable behavior while testing
|
||||
the real negotiation coordination logic.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from typing import Dict, List
|
||||
|
||||
from websocket_diplomacy_client import WebSocketDiplomacyClient
|
||||
from websocket_negotiations import (
|
||||
conduct_strategic_negotiation_round,
|
||||
analyze_recent_messages_for_targeting,
|
||||
should_participate_in_negotiations,
|
||||
get_negotiation_delay,
|
||||
)
|
||||
from ai_diplomacy.agent import DiplomacyAgent
|
||||
from ai_diplomacy.game_history import GameHistory
|
||||
|
||||
|
||||
class TestMessageTargeting:
    """Test strategic message targeting analysis."""

    @pytest.fixture
    async def client_with_messages(self, fake_server):
        """Yield a synchronized client playing FRANCE in a fresh 7-power game.

        Despite the fixture name, no messages are pre-populated here; tests
        that need message activity send it themselves.
        """
        client = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)

        try:
            await client.connect_and_authenticate("targeting_user", "password")
            await client.create_game(
                map_name="standard",
                rules=["IGNORE_ERRORS", "POWER_CHOICE"],
                power_name="FRANCE",
                n_controls=7,  # Full game for realistic targeting
            )
            await client.synchronize()

            yield client

        finally:
            # Best-effort cleanup; only ``Exception`` is suppressed so that
            # cancellation / Ctrl-C still propagate.
            try:
                await client.close()
            except Exception:
                pass

    async def test_analyze_recent_messages_empty_history(self, client_with_messages):
        """Test targeting analysis with no message history."""
        client = client_with_messages

        # Get targeting priority with empty message history
        targets = await analyze_recent_messages_for_targeting(client=client, power_name="FRANCE", max_messages=20)

        assert isinstance(targets, list)
        # Should not include FRANCE itself
        assert "FRANCE" not in targets
        # Should include at least one of the other major powers
        expected_powers = {"ENGLAND", "GERMANY", "ITALY", "AUSTRIA", "RUSSIA", "TURKEY"}
        assert len(set(targets) & expected_powers) > 0

    async def test_analyze_recent_messages_with_activity(self, client_with_messages):
        """Test targeting analysis with simulated message activity."""
        client = client_with_messages

        # Send messages to create activity patterns; this exercises the actual
        # message retrieval and analysis logic.
        await client.send_message("ENGLAND", "FRANCE", "Hello France!")
        await client.send_message("ENGLAND", "FRANCE", "Another message from England")
        await client.send_message("GERMANY", "FRANCE", "Message from Germany")
        await client.send_message("ITALY", "GLOBAL", "Global message from Italy")

        await asyncio.sleep(0.1)
        await client.synchronize()

        # Now analyze targeting
        targets = await analyze_recent_messages_for_targeting(client=client, power_name="FRANCE", max_messages=20)

        assert isinstance(targets, list)
        assert "FRANCE" not in targets

        # ENGLAND sent two direct messages to FRANCE and GERMANY one, so England
        # should rank first. The ordering is only checked when both powers are
        # present in the result.
        if "ENGLAND" in targets and "GERMANY" in targets:
            england_index = targets.index("ENGLAND")
            germany_index = targets.index("GERMANY")
            assert england_index < germany_index, "England should be prioritized over Germany"

    async def test_analyze_messages_error_handling(self, client_with_messages):
        """Test error handling in message analysis."""
        client = client_with_messages

        # Test with invalid power name
        targets = await analyze_recent_messages_for_targeting(client=client, power_name="INVALID_POWER", max_messages=20)

        # Should still return a valid list (fallback behavior)
        assert isinstance(targets, list)
|
||||
|
||||
|
||||
class TestNegotiationTiming:
    """Test negotiation delay and timing logic."""

    def test_negotiation_delay_calculation(self):
        """Check the relative ordering of first, middle and final round delays."""
        opening, middle, closing = (get_negotiation_delay(round_number=number, total_rounds=3) for number in (1, 2, 3))

        # The first round gets extra time, the final round gets less
        assert opening > middle, "First round should have longer delay"
        assert closing < middle, "Final round should have shorter delay"

        # All delays should be positive
        assert opening > 0
        assert middle > 0
        assert closing > 0

    def test_negotiation_delay_edge_cases(self):
        """Check that unusual round counts still produce a positive delay."""
        # A single-round negotiation
        only_round = get_negotiation_delay(round_number=1, total_rounds=1)
        assert only_round > 0

        # A middle round of a long negotiation
        mid_of_many = get_negotiation_delay(round_number=5, total_rounds=10)
        assert mid_of_many > 0
|
||||
|
||||
|
||||
class TestNegotiationParticipation:
    """Test logic for determining negotiation participation."""

    @staticmethod
    def _mock_power(eliminated):
        """Return a mock power whose ``is_eliminated()`` reports ``eliminated``."""
        power = MagicMock()
        power.is_eliminated.return_value = eliminated
        return power

    @pytest.fixture
    async def client_and_mocked_agent(self, fake_server):
        """Yield a client playing RUSSIA in a fresh game plus a mocked agent.

        Yields:
            dict with ``client`` and ``agent`` (a ``DiplomacyAgent`` mock whose
            ``power_name`` is "RUSSIA").
        """
        client = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)

        try:
            await client.connect_and_authenticate("participation_user", "password")
            await client.create_game(
                map_name="standard",
                rules=["IGNORE_ERRORS"],
                power_name="RUSSIA",
                n_controls=1,
            )
            await client.synchronize()

            # Create mock agent
            mock_agent = MagicMock(spec=DiplomacyAgent)
            mock_agent.power_name = "RUSSIA"

            yield {"client": client, "agent": mock_agent}

        finally:
            # Best-effort cleanup; only ``Exception`` is suppressed so that
            # cancellation / Ctrl-C still propagate.
            try:
                await client.close()
            except Exception:
                pass

    async def test_should_participate_eliminated_power(self, client_and_mocked_agent):
        """Test that eliminated powers don't participate in negotiations."""
        client = client_and_mocked_agent["client"]
        agent = client_and_mocked_agent["agent"]

        # Mock the power as eliminated
        with patch.object(client, "get_power") as mock_get_power:
            mock_get_power.return_value = self._mock_power(eliminated=True)

            should_participate = await should_participate_in_negotiations(client, agent)
            assert should_participate is False

    async def test_should_participate_no_orders(self, client_and_mocked_agent):
        """Test that powers with no orderable locations don't negotiate."""
        client = client_and_mocked_agent["client"]
        agent = client_and_mocked_agent["agent"]

        # Mock the power as not eliminated but with no orderable locations
        with patch.object(client, "get_power") as mock_get_power, patch("websocket_negotiations.gather_possible_orders") as mock_orders:
            mock_get_power.return_value = self._mock_power(eliminated=False)
            mock_orders.return_value = []  # No possible orders

            should_participate = await should_participate_in_negotiations(client, agent)
            assert should_participate is False

    async def test_should_participate_non_movement_phase(self, client_and_mocked_agent):
        """Test that powers don't negotiate in non-movement phases."""
        client = client_and_mocked_agent["client"]
        agent = client_and_mocked_agent["agent"]

        # Mock retreat phase
        with (
            patch.object(client, "get_power") as mock_get_power,
            patch.object(client, "get_current_short_phase") as mock_phase,
            patch("websocket_negotiations.gather_possible_orders") as mock_orders,
        ):
            mock_get_power.return_value = self._mock_power(eliminated=False)
            mock_orders.return_value = ["A Moscow - Hold"]  # Has orders
            mock_phase.return_value = "S1901R"  # Retreat phase

            should_participate = await should_participate_in_negotiations(client, agent)
            assert should_participate is False

    async def test_should_participate_movement_phase_with_orders(self, client_and_mocked_agent):
        """Test that active powers with orders participate in movement phases."""
        client = client_and_mocked_agent["client"]
        agent = client_and_mocked_agent["agent"]

        # Mock movement phase with orderable locations
        with (
            patch.object(client, "get_power") as mock_get_power,
            patch.object(client, "get_current_short_phase") as mock_phase,
            patch("websocket_negotiations.gather_possible_orders") as mock_orders,
        ):
            mock_get_power.return_value = self._mock_power(eliminated=False)
            mock_orders.return_value = ["A Moscow - Hold", "F Sevastopol - Black Sea"]
            mock_phase.return_value = "S1901M"  # Movement phase

            should_participate = await should_participate_in_negotiations(client, agent)
            assert should_participate is True
|
||||
|
||||
|
||||
class TestStrategicNegotiationRound:
|
||||
"""Test the complete strategic negotiation round functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
async def negotiation_setup(self, fake_server):
|
||||
"""Setup for testing complete negotiation rounds."""
|
||||
client = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
|
||||
|
||||
try:
|
||||
await client.connect_and_authenticate("negotiation_user", "password")
|
||||
await client.create_game(
|
||||
map_name="standard",
|
||||
rules=["IGNORE_ERRORS", "POWER_CHOICE"],
|
||||
power_name="TURKEY",
|
||||
n_controls=1,
|
||||
)
|
||||
await client.synchronize()
|
||||
|
||||
# Create mock agent with necessary attributes
|
||||
mock_agent = MagicMock(spec=DiplomacyAgent)
|
||||
mock_agent.power_name = "TURKEY"
|
||||
mock_agent.goals = "Expand into the Mediterranean"
|
||||
mock_agent.relationships = {}
|
||||
mock_agent.format_private_diary_for_prompt.return_value = "Test diary entry"
|
||||
|
||||
# Mock AI client for message generation
|
||||
mock_ai_client = AsyncMock()
|
||||
mock_agent.client = mock_ai_client
|
||||
|
||||
# Create game history
|
||||
game_history = GameHistory()
|
||||
|
||||
# Mock error stats
|
||||
error_stats = {"test_model": {"conversation_errors": 0, "order_decoding_errors": 0}}
|
||||
|
||||
yield {"client": client, "agent": mock_agent, "game_history": game_history, "error_stats": error_stats}
|
||||
|
||||
finally:
|
||||
try:
|
||||
await client.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
async def test_successful_negotiation_round(self, negotiation_setup):
|
||||
"""Test a successful negotiation round with message generation."""
|
||||
setup = negotiation_setup
|
||||
client = setup["client"]
|
||||
agent = setup["agent"]
|
||||
game_history = setup["game_history"]
|
||||
error_stats = setup["error_stats"]
|
||||
|
||||
# Mock AI response for message generation
|
||||
mock_messages = [
|
||||
{"content": "Greetings! Turkey seeks peaceful relations.", "message_type": "global"},
|
||||
{"content": "Russia, shall we coordinate our efforts?", "message_type": "private", "recipient": "RUSSIA"},
|
||||
]
|
||||
agent.client.get_conversation_reply.return_value = mock_messages
|
||||
|
||||
# Mock power and orders
|
||||
with (
|
||||
patch.object(client, "get_power") as mock_get_power,
|
||||
patch("websocket_negotiations.gather_possible_orders") as mock_orders,
|
||||
patch.object(client, "send_message") as mock_send,
|
||||
):
|
||||
mock_power = MagicMock()
|
||||
mock_power.is_eliminated.return_value = False
|
||||
mock_get_power.return_value = mock_power
|
||||
mock_orders.return_value = ["A Constantinople - Hold"]
|
||||
|
||||
# Run negotiation round
|
||||
success = await conduct_strategic_negotiation_round(
|
||||
client=client,
|
||||
agent=agent,
|
||||
game_history=game_history,
|
||||
model_error_stats=error_stats,
|
||||
log_file_path="/tmp/test_log.txt",
|
||||
round_number=1,
|
||||
max_rounds=3,
|
||||
)
|
||||
|
||||
assert success is True
|
||||
# Should have called send_message for each generated message
|
||||
assert mock_send.call_count == len(mock_messages)
|
||||
|
||||
async def test_negotiation_round_no_messages(self, negotiation_setup):
|
||||
"""Test negotiation round when AI generates no messages."""
|
||||
setup = negotiation_setup
|
||||
client = setup["client"]
|
||||
agent = setup["agent"]
|
||||
game_history = setup["game_history"]
|
||||
error_stats = setup["error_stats"]
|
||||
|
||||
# Mock AI response with no messages
|
||||
agent.client.get_conversation_reply.return_value = []
|
||||
|
||||
# Mock power and orders
|
||||
with patch.object(client, "get_power") as mock_get_power, patch("websocket_negotiations.gather_possible_orders") as mock_orders:
|
||||
mock_power = MagicMock()
|
||||
mock_power.is_eliminated.return_value = False
|
||||
mock_get_power.return_value = mock_power
|
||||
mock_orders.return_value = ["A Constantinople - Hold"]
|
||||
|
||||
# Run negotiation round
|
||||
success = await conduct_strategic_negotiation_round(
|
||||
client=client,
|
||||
agent=agent,
|
||||
game_history=game_history,
|
||||
model_error_stats=error_stats,
|
||||
log_file_path="/tmp/test_log.txt",
|
||||
round_number=1,
|
||||
max_rounds=3,
|
||||
)
|
||||
|
||||
assert success is False
|
||||
|
||||
async def test_negotiation_round_ai_error(self, negotiation_setup):
|
||||
"""Test negotiation round when AI client raises an exception."""
|
||||
setup = negotiation_setup
|
||||
client = setup["client"]
|
||||
agent = setup["agent"]
|
||||
game_history = setup["game_history"]
|
||||
error_stats = setup["error_stats"]
|
||||
|
||||
# Mock AI client to raise exception
|
||||
agent.client.get_conversation_reply.side_effect = Exception("AI service unavailable")
|
||||
agent.client.model_name = "test_model"
|
||||
|
||||
# Mock power and orders
|
||||
with patch.object(client, "get_power") as mock_get_power, patch("websocket_negotiations.gather_possible_orders") as mock_orders:
|
||||
mock_power = MagicMock()
|
||||
mock_power.is_eliminated.return_value = False
|
||||
mock_get_power.return_value = mock_power
|
||||
mock_orders.return_value = ["A Constantinople - Hold"]
|
||||
|
||||
# Run negotiation round
|
||||
success = await conduct_strategic_negotiation_round(
|
||||
client=client,
|
||||
agent=agent,
|
||||
game_history=game_history,
|
||||
model_error_stats=error_stats,
|
||||
log_file_path="/tmp/test_log.txt",
|
||||
round_number=1,
|
||||
max_rounds=3,
|
||||
)
|
||||
|
||||
assert success is False
|
||||
# Error should be tracked in statistics
|
||||
assert error_stats["test_model"]["conversation_errors"] == 1
|
||||
|
||||
async def test_negotiation_round_with_targeting(self, negotiation_setup):
    """Check that a later negotiation round consults the targeting analysis.

    Runs round 2 of 3 with ``analyze_recent_messages_for_targeting`` patched
    out, then asserts the patched function was called exactly once with the
    client and ``"TURKEY"``.
    """
    ws_client = negotiation_setup["client"]
    bot_agent = negotiation_setup["agent"]
    history = negotiation_setup["game_history"]
    stats = negotiation_setup["error_stats"]

    # Canned AI reply: a single global message.
    bot_agent.client.get_conversation_reply.return_value = [
        {"content": "Test message", "message_type": "global"},
    ]

    # Stand-in power object that reports it is still in the game.
    active_power = MagicMock()
    active_power.is_eliminated.return_value = False

    targeting_patch = patch(
        "websocket_negotiations.analyze_recent_messages_for_targeting",
        return_value=["RUSSIA", "AUSTRIA", "ITALY"],
    )
    power_patch = patch.object(ws_client, "get_power", return_value=active_power)
    orders_patch = patch(
        "websocket_negotiations.gather_possible_orders",
        return_value=["A Constantinople - Hold"],
    )

    with targeting_patch as targeting_mock, power_patch, orders_patch:
        # Round 2: rounds after the first are expected to use targeting.
        await conduct_strategic_negotiation_round(
            client=ws_client,
            agent=bot_agent,
            game_history=history,
            model_error_stats=stats,
            log_file_path="/tmp/test_log.txt",
            round_number=2,
            max_rounds=3,
        )

    # The targeting analysis must have been consulted exactly once.
    targeting_mock.assert_called_once_with(ws_client, "TURKEY")
|
||||
|
|
@ -19,16 +19,14 @@ from models import (
|
|||
serialize_message,
|
||||
parse_message,
|
||||
)
|
||||
from typed_websocket_client import TypedWebSocketDiplomacyClient
|
||||
from websocket_diplomacy_client import WebSocketDiplomacyClient
|
||||
|
||||
|
||||
class TestAuthentication:
|
||||
"""Test authentication flow and message handling."""
|
||||
|
||||
@pytest.mark.websocket
|
||||
async def test_successful_authentication(
|
||||
self, client: TypedWebSocketDiplomacyClient, helpers
|
||||
):
|
||||
async def test_successful_authentication(self, client: WebSocketDiplomacyClient, helpers):
|
||||
"""Test successful authentication with valid credentials."""
|
||||
await client.connect_and_authenticate("test_user", "test_password")
|
||||
helpers.assert_valid_token(client.token)
|
||||
|
|
@ -44,16 +42,12 @@ class TestAuthentication:
|
|||
assert client.token is None
|
||||
|
||||
@pytest.mark.websocket
|
||||
async def test_authentication_message_format(
|
||||
self, credentials: Dict[str, str], client: TypedWebSocketDiplomacyClient
|
||||
):
|
||||
async def test_authentication_message_format(self, credentials: Dict[str, str], client: WebSocketDiplomacyClient):
|
||||
"""Test that authentication messages are properly formatted."""
|
||||
await client.connect_and_authenticate(**credentials)
|
||||
|
||||
# Create a sign-in request manually to test message format
|
||||
request = SignInRequest(
|
||||
request_id=str(uuid.uuid4()), username="test_user", password="test_password"
|
||||
)
|
||||
request = SignInRequest(request_id=str(uuid.uuid4()), username="test_user", password="test_password")
|
||||
|
||||
# Verify the request serializes correctly
|
||||
message_data = serialize_message(request)
|
||||
|
|
@ -74,9 +68,7 @@ class TestGameOperations:
|
|||
@pytest.mark.websocket
|
||||
async def test_create_game(self, authenticated_client, helpers):
|
||||
"""Test creating a new game."""
|
||||
game_data = await authenticated_client.create_game(
|
||||
power_name="FRANCE", n_controls=1
|
||||
)
|
||||
game_data = await authenticated_client.create_game(power_name="FRANCE", n_controls=1)
|
||||
|
||||
helpers.assert_valid_game_data(game_data)
|
||||
assert authenticated_client.game_id is not None
|
||||
|
|
@ -98,9 +90,7 @@ class TestGameOperations:
|
|||
assert authenticated_client.game_id in game_ids
|
||||
|
||||
@pytest.mark.websocket
|
||||
async def test_join_existing_game(
|
||||
self, client: TypedWebSocketDiplomacyClient, helpers
|
||||
):
|
||||
async def test_join_existing_game(self, client: WebSocketDiplomacyClient, helpers):
|
||||
"""Test joining an existing game."""
|
||||
# Connect and authenticate a first client
|
||||
await client.connect()
|
||||
|
|
@ -130,9 +120,7 @@ class TestGameOperations:
|
|||
async def test_join_nonexistent_game(self, authenticated_client):
|
||||
"""Test joining a game that doesn't exist."""
|
||||
with pytest.raises(ValueError, match="Game .* not found"):
|
||||
await authenticated_client.join_game(
|
||||
"NONEXISTENT_GAME", power_name="FRANCE"
|
||||
)
|
||||
await authenticated_client.join_game("NONEXISTENT_GAME", power_name="FRANCE")
|
||||
|
||||
|
||||
class TestGamePlay:
|
||||
|
|
@ -246,9 +234,7 @@ class TestMessageValidation:
|
|||
"""Test that messages can be serialized and parsed correctly."""
|
||||
# Test various message types
|
||||
messages = [
|
||||
SignInRequest(
|
||||
request_id=str(uuid.uuid4()), username="test", password="pass"
|
||||
),
|
||||
SignInRequest(request_id=str(uuid.uuid4()), username="test", password="pass"),
|
||||
CreateGameRequest(
|
||||
request_id=str(uuid.uuid4()),
|
||||
token="test_token",
|
||||
|
|
@ -289,7 +275,7 @@ class TestConcurrentOperations:
|
|||
try:
|
||||
# Create multiple clients
|
||||
for i in range(3):
|
||||
client = TypedWebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
|
||||
client = WebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
|
||||
await client.connect()
|
||||
await client.authenticate("test_user", "test_password")
|
||||
clients.append(client)
|
||||
|
|
@ -300,9 +286,7 @@ class TestConcurrentOperations:
|
|||
|
||||
# Each can create games independently
|
||||
for i, client in enumerate(clients):
|
||||
game_data = await client.create_game(
|
||||
power_name="FRANCE" if i == 0 else None, n_controls=1
|
||||
)
|
||||
game_data = await client.create_game(power_name="FRANCE" if i == 0 else None, n_controls=1)
|
||||
assert game_data["game_id"] is not None
|
||||
|
||||
finally:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue