mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-05-02 17:46:00 +00:00
WIP: Continuing to add pieces for the LMs to play over websocket.
Signed-off-by: Tyler Marques <me@tylermarques.com>
This commit is contained in:
parent
939da088a5
commit
cc7fb1ccd3
11 changed files with 1101 additions and 107 deletions
1
bot_client/tests/__init__.py
Normal file
1
bot_client/tests/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
# Test package initialization
|
||||
123
bot_client/tests/conftest.py
Normal file
123
bot_client/tests/conftest.py
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
"""
|
||||
Pytest configuration and fixtures for WebSocket testing.
|
||||
|
||||
This module provides shared fixtures and utilities for testing our
|
||||
WebSocket client implementations against a fake server.
|
||||
"""
|
||||
|
||||
# Add parent directory to path for ai_diplomacy imports (runtime only)
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
from typing import AsyncGenerator, Dict, Generator
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
# Add the bot_client directory to the path so we can import our modules
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
from .test_fake_websocket_server import FakeServerManager, FakeWebSocketServer
|
||||
from typed_websocket_client import TypedWebSocketDiplomacyClient
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function", autouse=True)
|
||||
async def fake_server() -> AsyncGenerator[FakeWebSocketServer, None]:
|
||||
"""
|
||||
Fixture that provides a running fake WebSocket server for testing.
|
||||
|
||||
The server is automatically started before the test and stopped after.
|
||||
Uses port 8433 to avoid conflicts with real servers on 8432.
|
||||
"""
|
||||
async with FakeServerManager("localhost", 8433) as server:
|
||||
yield server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def credentials() -> Generator[Dict[str, str], None]:
|
||||
yield {"username": "test_user", "password": "test_password"}
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client() -> AsyncGenerator[TypedWebSocketDiplomacyClient, None]:
|
||||
"""
|
||||
Fixture that provides a TypedWebSocketDiplomacyClient instance.
|
||||
|
||||
The client is configured to connect to the fake server on port 8433.
|
||||
"""
|
||||
client = TypedWebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
|
||||
yield client
|
||||
|
||||
# Cleanup: close the client connection
|
||||
try:
|
||||
await client.close()
|
||||
except:
|
||||
pass # Ignore cleanup errors
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def authenticated_client(fake_server, client):
|
||||
"""
|
||||
Fixture that provides an authenticated TypedWebSocketDiplomacyClient.
|
||||
|
||||
This client is already connected and authenticated, ready for testing
|
||||
game operations.
|
||||
"""
|
||||
await client.connect()
|
||||
token = await client.authenticate("test_user", "test_password")
|
||||
assert token is not None
|
||||
assert client.token == token
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client_with_game(authenticated_client):
|
||||
"""
|
||||
Fixture that provides an authenticated client with a created game.
|
||||
|
||||
The client is in a game as FRANCE, ready for order submission and
|
||||
game interaction testing.
|
||||
"""
|
||||
game_data = await authenticated_client.create_game(
|
||||
power_name="FRANCE",
|
||||
n_controls=1, # For testing
|
||||
)
|
||||
assert game_data is not None
|
||||
assert authenticated_client.game_id is not None
|
||||
yield authenticated_client
|
||||
|
||||
|
||||
# Test utilities
|
||||
class TestHelpers:
|
||||
"""Helper methods for testing WebSocket interactions."""
|
||||
|
||||
@staticmethod
|
||||
def assert_valid_token(token: str):
|
||||
"""Assert that a token looks valid."""
|
||||
assert isinstance(token, str)
|
||||
assert len(token) > 10
|
||||
assert "fake_token_" in token
|
||||
|
||||
@staticmethod
|
||||
def assert_valid_game_data(game_data: dict):
|
||||
"""Assert that game data has expected structure."""
|
||||
assert isinstance(game_data, dict)
|
||||
assert "game_id" in game_data
|
||||
assert "phase" in game_data
|
||||
assert "powers" in game_data
|
||||
assert "map_name" in game_data
|
||||
|
||||
@staticmethod
|
||||
def assert_valid_games_list(games: list):
|
||||
"""Assert that games list has expected structure."""
|
||||
assert isinstance(games, list)
|
||||
for game in games:
|
||||
assert "game_id" in game
|
||||
assert "status" in game
|
||||
assert "phase" in game
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def helpers():
|
||||
"""Fixture that provides test helper methods."""
|
||||
return TestHelpers
|
||||
378
bot_client/tests/test_fake_websocket_server.py
Normal file
378
bot_client/tests/test_fake_websocket_server.py
Normal file
|
|
@ -0,0 +1,378 @@
|
|||
"""
|
||||
Fake WebSocket Server for Testing
|
||||
|
||||
This module provides a mock Diplomacy WebSocket server that implements the protocol
|
||||
from WEBSOCKET.md for testing purposes. It responds to typed messages with appropriate
|
||||
responses, allowing us to test our client implementation.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
import websockets
|
||||
from typing import Dict, Any, Optional, Set
|
||||
from loguru import logger
|
||||
|
||||
from models import (
|
||||
# Request types
|
||||
SignInRequest,
|
||||
CreateGameRequest,
|
||||
JoinGameRequest,
|
||||
ListGamesRequest,
|
||||
SetOrdersRequest,
|
||||
ProcessGameRequest,
|
||||
GetAllPossibleOrdersRequest,
|
||||
# Response types
|
||||
DataTokenResponse,
|
||||
DataGameResponse,
|
||||
DataGamesResponse,
|
||||
DataPossibleOrdersResponse,
|
||||
OkResponse,
|
||||
ErrorResponse,
|
||||
# Notifications
|
||||
parse_message,
|
||||
serialize_message,
|
||||
)
|
||||
|
||||
|
||||
class FakeWebSocketServer:
|
||||
"""
|
||||
A fake WebSocket server that implements the Diplomacy protocol for testing.
|
||||
|
||||
This server maintains minimal state and responds to requests with valid
|
||||
responses according to the WEBSOCKET.md protocol specification.
|
||||
"""
|
||||
|
||||
def __init__(self, host: str = "localhost", port: int = 8433):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.server = None
|
||||
self.connected_clients: Set[websockets.WebSocketServerProtocol] = set()
|
||||
|
||||
# Mock server state
|
||||
self.users = {
|
||||
"test_user": "test_password",
|
||||
"ai_player": "password",
|
||||
"player1": "password"
|
||||
}
|
||||
self.tokens: Dict[str, str] = {} # token -> username
|
||||
self.games: Dict[str, Dict[str, Any]] = {}
|
||||
self.game_counter = 1
|
||||
|
||||
async def start(self):
|
||||
"""Start the fake WebSocket server."""
|
||||
logger.info(f"Starting fake WebSocket server on {self.host}:{self.port}")
|
||||
self.server = await websockets.serve(
|
||||
self.handle_client,
|
||||
self.host,
|
||||
self.port
|
||||
)
|
||||
logger.info("Fake WebSocket server started")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the fake WebSocket server."""
|
||||
if self.server:
|
||||
self.server.close()
|
||||
await self.server.wait_closed()
|
||||
logger.info("Fake WebSocket server stopped")
|
||||
|
||||
async def handle_client(self, websocket):
|
||||
"""Handle a new WebSocket client connection."""
|
||||
self.connected_clients.add(websocket)
|
||||
logger.info(f"Client connected from {websocket.remote_address}")
|
||||
|
||||
try:
|
||||
async for message in websocket:
|
||||
await self.handle_message(websocket, message)
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logger.info("Client disconnected")
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling client: {e}")
|
||||
finally:
|
||||
self.connected_clients.discard(websocket)
|
||||
|
||||
async def handle_message(self, websocket, message_text: str):
|
||||
"""Handle an incoming message from a client."""
|
||||
try:
|
||||
# Parse the raw message
|
||||
data = json.loads(message_text)
|
||||
request = parse_message(data)
|
||||
|
||||
# Generate response based on request type
|
||||
response = await self.generate_response(request)
|
||||
|
||||
# Send response back to client
|
||||
if response:
|
||||
response_data = serialize_message(response)
|
||||
await websocket.send(json.dumps(response_data))
|
||||
logger.debug(f"Sent response: {response.name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message: {e}")
|
||||
logger.debug(f"Raw message: {message_text}")
|
||||
|
||||
# Send error response if we can extract request_id
|
||||
try:
|
||||
data = json.loads(message_text)
|
||||
request_id = data.get("request_id", str(uuid.uuid4()))
|
||||
error_response = ErrorResponse(
|
||||
request_id=request_id,
|
||||
error_type="PARSING_ERROR",
|
||||
message=str(e)
|
||||
)
|
||||
error_data = serialize_message(error_response)
|
||||
await websocket.send(json.dumps(error_data))
|
||||
except:
|
||||
logger.error("Could not send error response")
|
||||
|
||||
async def generate_response(self, request) -> Optional[Any]:
|
||||
"""Generate appropriate response for a request."""
|
||||
logger.info(f"Processing request: {request.name}")
|
||||
|
||||
# Authentication requests
|
||||
if isinstance(request, SignInRequest):
|
||||
return await self.handle_sign_in(request)
|
||||
|
||||
# Channel-level requests (require token)
|
||||
elif isinstance(request, CreateGameRequest):
|
||||
return await self.handle_create_game(request)
|
||||
elif isinstance(request, JoinGameRequest):
|
||||
return await self.handle_join_game(request)
|
||||
elif isinstance(request, ListGamesRequest):
|
||||
return await self.handle_list_games(request)
|
||||
|
||||
# Game-level requests (require token + game context)
|
||||
elif isinstance(request, SetOrdersRequest):
|
||||
return await self.handle_set_orders(request)
|
||||
elif isinstance(request, ProcessGameRequest):
|
||||
return await self.handle_process_game(request)
|
||||
elif isinstance(request, GetAllPossibleOrdersRequest):
|
||||
return await self.handle_get_possible_orders(request)
|
||||
|
||||
else:
|
||||
logger.warning(f"Unhandled request type: {type(request)}")
|
||||
return ErrorResponse(
|
||||
request_id=request.request_id,
|
||||
error_type="UNSUPPORTED_REQUEST",
|
||||
message=f"Request type {request.name} not supported by fake server"
|
||||
)
|
||||
|
||||
async def handle_sign_in(self, request: SignInRequest) -> Any:
|
||||
"""Handle authentication request."""
|
||||
username = request.username
|
||||
password = request.password
|
||||
|
||||
if username in self.users and self.users[username] == password:
|
||||
# Generate auth token
|
||||
token = f"fake_token_{uuid.uuid4().hex[:16]}"
|
||||
self.tokens[token] = username
|
||||
|
||||
return DataTokenResponse(
|
||||
request_id=request.request_id,
|
||||
data=token
|
||||
)
|
||||
else:
|
||||
return ErrorResponse(
|
||||
request_id=request.request_id,
|
||||
error_type="AUTHENTICATION_ERROR",
|
||||
message="Invalid username or password"
|
||||
)
|
||||
|
||||
async def handle_create_game(self, request: CreateGameRequest) -> Any:
|
||||
"""Handle game creation request."""
|
||||
if not self.validate_token(request.token):
|
||||
return self.create_auth_error(request.request_id)
|
||||
|
||||
# Create a new game
|
||||
game_id = f"GAME_{self.game_counter:04d}"
|
||||
self.game_counter += 1
|
||||
|
||||
game_data = {
|
||||
"game_id": game_id,
|
||||
"map_name": request.map_name,
|
||||
"rules": request.rules,
|
||||
"phase": "S1901M",
|
||||
"status": "FORMING",
|
||||
"n_controls": request.n_controls,
|
||||
"powers": {
|
||||
"AUSTRIA": {"units": [], "centers": ["VIE", "BUD", "TRI"], "is_eliminated": False},
|
||||
"ENGLAND": {"units": [], "centers": ["EDI", "LVP", "LON"], "is_eliminated": False},
|
||||
"FRANCE": {"units": [], "centers": ["PAR", "BRE", "MAR"], "is_eliminated": False},
|
||||
"GERMANY": {"units": [], "centers": ["BER", "MUN", "KIE"], "is_eliminated": False},
|
||||
"ITALY": {"units": [], "centers": ["ROM", "NAP", "VEN"], "is_eliminated": False},
|
||||
"RUSSIA": {"units": [], "centers": ["MOS", "SEV", "STP", "WAR"], "is_eliminated": False},
|
||||
"TURKEY": {"units": [], "centers": ["ANK", "CON", "SMY"], "is_eliminated": False},
|
||||
},
|
||||
"controlled_powers": {}
|
||||
}
|
||||
|
||||
# If a specific power was requested, assign it
|
||||
if request.power_name:
|
||||
game_data["controlled_powers"][request.power_name] = self.tokens[request.token]
|
||||
|
||||
self.games[game_id] = game_data
|
||||
|
||||
return DataGameResponse(
|
||||
request_id=request.request_id,
|
||||
data=game_data
|
||||
)
|
||||
|
||||
async def handle_join_game(self, request: JoinGameRequest) -> Any:
|
||||
"""Handle game join request."""
|
||||
if not self.validate_token(request.token):
|
||||
return self.create_auth_error(request.request_id)
|
||||
|
||||
game_id = request.game_id
|
||||
if game_id not in self.games:
|
||||
return ErrorResponse(
|
||||
request_id=request.request_id,
|
||||
error_type="GAME_NOT_FOUND",
|
||||
message=f"Game {game_id} not found"
|
||||
)
|
||||
|
||||
game_data = self.games[game_id].copy()
|
||||
|
||||
# If a specific power was requested, assign it
|
||||
if request.power_name:
|
||||
game_data["controlled_powers"][request.power_name] = self.tokens[request.token]
|
||||
self.games[game_id] = game_data
|
||||
|
||||
return DataGameResponse(
|
||||
request_id=request.request_id,
|
||||
data=game_data
|
||||
)
|
||||
|
||||
async def handle_list_games(self, request: ListGamesRequest) -> Any:
|
||||
"""Handle list games request."""
|
||||
if not self.validate_token(request.token):
|
||||
return self.create_auth_error(request.request_id)
|
||||
|
||||
# Return simplified game info
|
||||
games_list = []
|
||||
for game_id, game_data in self.games.items():
|
||||
games_list.append({
|
||||
"game_id": game_id,
|
||||
"map_name": game_data["map_name"],
|
||||
"status": game_data["status"],
|
||||
"phase": game_data["phase"],
|
||||
"n_controls": game_data["n_controls"]
|
||||
})
|
||||
|
||||
return DataGamesResponse(
|
||||
request_id=request.request_id,
|
||||
data=games_list
|
||||
)
|
||||
|
||||
async def handle_set_orders(self, request: SetOrdersRequest) -> Any:
|
||||
"""Handle set orders request."""
|
||||
if not self.validate_token(request.token):
|
||||
return self.create_auth_error(request.request_id)
|
||||
|
||||
game_id = request.game_id
|
||||
if game_id not in self.games:
|
||||
return ErrorResponse(
|
||||
request_id=request.request_id,
|
||||
error_type="GAME_NOT_FOUND",
|
||||
message=f"Game {game_id} not found"
|
||||
)
|
||||
|
||||
# In a real server, we'd validate the orders and store them
|
||||
# For testing, we just acknowledge receipt
|
||||
logger.info(f"Orders received for {request.game_role}: {request.orders}")
|
||||
|
||||
return OkResponse(request_id=request.request_id)
|
||||
|
||||
async def handle_process_game(self, request: ProcessGameRequest) -> Any:
|
||||
"""Handle process game request."""
|
||||
if not self.validate_token(request.token):
|
||||
return self.create_auth_error(request.request_id)
|
||||
|
||||
game_id = request.game_id
|
||||
if game_id not in self.games:
|
||||
return ErrorResponse(
|
||||
request_id=request.request_id,
|
||||
error_type="GAME_NOT_FOUND",
|
||||
message=f"Game {game_id} not found"
|
||||
)
|
||||
|
||||
# Simulate game processing
|
||||
game_data = self.games[game_id]
|
||||
current_phase = game_data["phase"]
|
||||
|
||||
# Simple phase progression
|
||||
if current_phase == "S1901M":
|
||||
game_data["phase"] = "F1901M"
|
||||
elif current_phase == "F1901M":
|
||||
game_data["phase"] = "W1901A"
|
||||
else:
|
||||
# For testing, cycle back to start
|
||||
game_data["phase"] = "S1902M"
|
||||
|
||||
self.games[game_id] = game_data
|
||||
|
||||
# Send notification to all clients (in real implementation)
|
||||
# For testing, we'll just return OK
|
||||
return OkResponse(request_id=request.request_id)
|
||||
|
||||
async def handle_get_possible_orders(self, request: GetAllPossibleOrdersRequest) -> Any:
|
||||
"""Handle get possible orders request."""
|
||||
if not self.validate_token(request.token):
|
||||
return self.create_auth_error(request.request_id)
|
||||
|
||||
# Return mock possible orders
|
||||
possible_orders = {
|
||||
"PAR": ["A PAR H", "A PAR - BUR", "A PAR - PIC", "A PAR - GAS"],
|
||||
"BRE": ["F BRE H", "F BRE - MAO", "F BRE - ENG", "F BRE - PIC"],
|
||||
"MAR": ["A MAR H", "A MAR - GAS", "A MAR - SPA", "A MAR - PIE"],
|
||||
}
|
||||
|
||||
return DataPossibleOrdersResponse(
|
||||
request_id=request.request_id,
|
||||
data=possible_orders
|
||||
)
|
||||
|
||||
def validate_token(self, token: str) -> bool:
|
||||
"""Validate an authentication token."""
|
||||
return token in self.tokens
|
||||
|
||||
def create_auth_error(self, request_id: str) -> ErrorResponse:
|
||||
"""Create a standard authentication error response."""
|
||||
return ErrorResponse(
|
||||
request_id=request_id,
|
||||
error_type="AUTHENTICATION_ERROR",
|
||||
message="Invalid or missing authentication token"
|
||||
)
|
||||
|
||||
|
||||
class FakeServerManager:
|
||||
"""Context manager for the fake WebSocket server."""
|
||||
|
||||
def __init__(self, host: str = "localhost", port: int = 8433):
|
||||
self.server = FakeWebSocketServer(host, port)
|
||||
|
||||
async def __aenter__(self):
|
||||
await self.server.start()
|
||||
# Give the server a moment to start
|
||||
await asyncio.sleep(0.1)
|
||||
return self.server
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.server.stop()
|
||||
|
||||
|
||||
# Utility function for tests
|
||||
async def run_fake_server(host: str = "localhost", port: int = 8433):
|
||||
"""Run the fake server (for manual testing)."""
|
||||
async with FakeServerManager(host, port) as server:
|
||||
logger.info(f"Fake server running on {host}:{port}")
|
||||
logger.info("Press Ctrl+C to stop")
|
||||
try:
|
||||
# Keep running until interrupted
|
||||
await asyncio.Event().wait()
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Server stopped by user")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the fake server for manual testing
|
||||
asyncio.run(run_fake_server())
|
||||
324
bot_client/tests/test_websocket_client.py
Normal file
324
bot_client/tests/test_websocket_client.py
Normal file
|
|
@ -0,0 +1,324 @@
|
|||
"""
|
||||
Tests for WebSocket client functionality using the fake server.
|
||||
|
||||
These tests verify that our typed WebSocket client can properly:
|
||||
1. Connect to a WebSocket server
|
||||
2. Send and receive typed messages according to the protocol
|
||||
3. Handle authentication, game creation, and game operations
|
||||
4. Validate message format and content
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
import pytest
|
||||
import uuid
|
||||
|
||||
from models import (
|
||||
SignInRequest,
|
||||
CreateGameRequest,
|
||||
SetOrdersRequest,
|
||||
serialize_message,
|
||||
parse_message,
|
||||
)
|
||||
from typed_websocket_client import TypedWebSocketDiplomacyClient
|
||||
|
||||
|
||||
class TestAuthentication:
|
||||
"""Test authentication flow and message handling."""
|
||||
|
||||
@pytest.mark.websocket
|
||||
async def test_successful_authentication(
|
||||
self, client: TypedWebSocketDiplomacyClient, helpers
|
||||
):
|
||||
"""Test successful authentication with valid credentials."""
|
||||
await client.connect_and_authenticate("test_user", "test_password")
|
||||
helpers.assert_valid_token(client.token)
|
||||
|
||||
@pytest.mark.websocket
|
||||
async def test_failed_authentication(self, client):
|
||||
"""Test authentication failure with invalid credentials."""
|
||||
await client.connect()
|
||||
|
||||
with pytest.raises(ValueError, match="Authentication failed"):
|
||||
await client.authenticate("invalid_user", "wrong_password")
|
||||
|
||||
assert client.token is None
|
||||
|
||||
@pytest.mark.websocket
|
||||
async def test_authentication_message_format(
|
||||
self, credentials: Dict[str, str], client: TypedWebSocketDiplomacyClient
|
||||
):
|
||||
"""Test that authentication messages are properly formatted."""
|
||||
await client.connect_and_authenticate(**credentials)
|
||||
|
||||
# Create a sign-in request manually to test message format
|
||||
request = SignInRequest(
|
||||
request_id=str(uuid.uuid4()), username="test_user", password="test_password"
|
||||
)
|
||||
|
||||
# Verify the request serializes correctly
|
||||
message_data = serialize_message(request)
|
||||
assert message_data["name"] == "sign_in"
|
||||
assert message_data["username"] == "test_user"
|
||||
assert message_data["password"] == "test_password"
|
||||
assert "request_id" in message_data
|
||||
|
||||
# Test that we can parse it back
|
||||
parsed = parse_message(message_data)
|
||||
assert isinstance(parsed, SignInRequest)
|
||||
assert parsed.username == "test_user"
|
||||
|
||||
|
||||
class TestGameOperations:
|
||||
"""Test game creation, joining, and basic operations."""
|
||||
|
||||
@pytest.mark.websocket
|
||||
async def test_create_game(self, authenticated_client, helpers):
|
||||
"""Test creating a new game."""
|
||||
game_data = await authenticated_client.create_game(
|
||||
power_name="FRANCE", n_controls=1
|
||||
)
|
||||
|
||||
helpers.assert_valid_game_data(game_data)
|
||||
assert authenticated_client.game_id is not None
|
||||
assert authenticated_client.game_role == "FRANCE"
|
||||
|
||||
@pytest.mark.websocket
|
||||
async def test_list_games(self, authenticated_client, helpers):
|
||||
"""Test listing available games."""
|
||||
# First create a game so there's something to list
|
||||
await authenticated_client.create_game(power_name="ENGLAND", n_controls=1)
|
||||
|
||||
games = await authenticated_client.list_games()
|
||||
|
||||
helpers.assert_valid_games_list(games)
|
||||
assert len(games) >= 1
|
||||
|
||||
# Check that our created game is in the list
|
||||
game_ids = [game["game_id"] for game in games]
|
||||
assert authenticated_client.game_id in game_ids
|
||||
|
||||
@pytest.mark.websocket
|
||||
async def test_join_existing_game(
|
||||
self, client: TypedWebSocketDiplomacyClient, helpers
|
||||
):
|
||||
"""Test joining an existing game."""
|
||||
# Connect and authenticate a first client
|
||||
await client.connect()
|
||||
await client.authenticate("test_user", "test_password")
|
||||
|
||||
# Create a game
|
||||
game_data = await client.create_game(power_name="FRANCE", n_controls=1)
|
||||
game_id = game_data["game_id"]
|
||||
|
||||
# Create a second client to join the game
|
||||
client2 = client.__class__("localhost", 8433, use_ssl=False)
|
||||
try:
|
||||
await client2.connect()
|
||||
await client2.authenticate("ai_player", "password")
|
||||
|
||||
# Join the existing game as a different power
|
||||
joined_game_data = await client2.join_game(game_id, power_name="ENGLAND")
|
||||
|
||||
helpers.assert_valid_game_data(joined_game_data)
|
||||
assert joined_game_data["game_id"] == game_id
|
||||
assert client2.game_role == "ENGLAND"
|
||||
|
||||
finally:
|
||||
await client2.close()
|
||||
|
||||
@pytest.mark.websocket
|
||||
async def test_join_nonexistent_game(self, authenticated_client):
|
||||
"""Test joining a game that doesn't exist."""
|
||||
with pytest.raises(ValueError, match="Game .* not found"):
|
||||
await authenticated_client.join_game(
|
||||
"NONEXISTENT_GAME", power_name="FRANCE"
|
||||
)
|
||||
|
||||
|
||||
class TestGamePlay:
|
||||
"""Test actual gameplay operations like setting orders and processing."""
|
||||
|
||||
@pytest.mark.websocket
|
||||
async def test_set_orders(self, client_with_game):
|
||||
"""Test submitting orders for a power."""
|
||||
orders = ["A PAR H", "F BRE H", "A MAR H"]
|
||||
|
||||
# Should not raise an exception
|
||||
await client_with_game.set_orders("FRANCE", orders)
|
||||
|
||||
@pytest.mark.websocket
|
||||
async def test_set_empty_orders(self, client_with_game):
|
||||
"""Test submitting empty orders."""
|
||||
await client_with_game.set_orders("FRANCE", [])
|
||||
|
||||
@pytest.mark.websocket
|
||||
async def test_process_game(self, client_with_game):
|
||||
"""Test processing the game phase."""
|
||||
# Should not raise an exception (fake server allows processing)
|
||||
await client_with_game.process_game()
|
||||
|
||||
@pytest.mark.websocket
|
||||
async def test_get_possible_orders(self, client_with_game):
|
||||
"""Test getting possible orders for the current phase."""
|
||||
possible_orders = await client_with_game.get_all_possible_orders()
|
||||
|
||||
assert isinstance(possible_orders, dict)
|
||||
# Fake server returns some mock orders
|
||||
assert len(possible_orders) > 0
|
||||
|
||||
# Check format: location -> list of orders
|
||||
for location, orders in possible_orders.items():
|
||||
assert isinstance(location, str)
|
||||
assert isinstance(orders, list)
|
||||
for order in orders:
|
||||
assert isinstance(order, str)
|
||||
|
||||
|
||||
class TestMessageValidation:
|
||||
"""Test message format validation and error handling."""
|
||||
|
||||
@pytest.mark.websocket
|
||||
async def test_create_game_message_format(self, authenticated_client):
|
||||
"""Test that create game messages have correct format."""
|
||||
request = CreateGameRequest(
|
||||
request_id=str(uuid.uuid4()),
|
||||
token=authenticated_client.token,
|
||||
map_name="standard",
|
||||
rules=["NO_PRESS", "IGNORE_ERRORS"],
|
||||
power_name="AUSTRIA",
|
||||
n_controls=7,
|
||||
)
|
||||
|
||||
message_data = serialize_message(request)
|
||||
|
||||
# Verify all required fields are present
|
||||
assert message_data["name"] == "create_game"
|
||||
assert message_data["token"] == authenticated_client.token
|
||||
assert message_data["map_name"] == "standard"
|
||||
assert message_data["rules"] == ["NO_PRESS", "IGNORE_ERRORS"]
|
||||
assert message_data["power_name"] == "AUSTRIA"
|
||||
assert message_data["n_controls"] == 7
|
||||
assert "request_id" in message_data
|
||||
|
||||
@pytest.mark.websocket
|
||||
async def test_set_orders_message_format(self, client_with_game):
|
||||
"""Test that set orders messages have correct format."""
|
||||
request = SetOrdersRequest(
|
||||
request_id=str(uuid.uuid4()),
|
||||
token=client_with_game.token,
|
||||
game_id=client_with_game.game_id,
|
||||
game_role="FRANCE",
|
||||
orders=["A PAR H", "F BRE - ENG"],
|
||||
)
|
||||
|
||||
message_data = serialize_message(request)
|
||||
|
||||
# Verify all required fields are present
|
||||
assert message_data["name"] == "set_orders"
|
||||
assert message_data["token"] == client_with_game.token
|
||||
assert message_data["game_id"] == client_with_game.game_id
|
||||
assert message_data["game_role"] == "FRANCE"
|
||||
assert message_data["orders"] == ["A PAR H", "F BRE - ENG"]
|
||||
assert "request_id" in message_data
|
||||
|
||||
@pytest.mark.websocket
|
||||
async def test_error_response_handling(self, client):
|
||||
"""Test that error responses are properly handled."""
|
||||
await client.connect()
|
||||
|
||||
# Try to authenticate with wrong credentials
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await client.authenticate("bad_user", "bad_pass")
|
||||
|
||||
assert "Authentication failed" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.websocket
|
||||
async def test_unauthenticated_requests_fail(self, client):
|
||||
"""Test that requests without authentication fail appropriately."""
|
||||
await client.connect()
|
||||
# Don't authenticate
|
||||
|
||||
with pytest.raises(ValueError, match="Must authenticate first"):
|
||||
await client.create_game()
|
||||
|
||||
@pytest.mark.websocket
|
||||
async def test_message_round_trip(self, authenticated_client):
|
||||
"""Test that messages can be serialized and parsed correctly."""
|
||||
# Test various message types
|
||||
messages = [
|
||||
SignInRequest(
|
||||
request_id=str(uuid.uuid4()), username="test", password="pass"
|
||||
),
|
||||
CreateGameRequest(
|
||||
request_id=str(uuid.uuid4()),
|
||||
token="test_token",
|
||||
map_name="standard",
|
||||
power_name="FRANCE",
|
||||
),
|
||||
SetOrdersRequest(
|
||||
request_id=str(uuid.uuid4()),
|
||||
token="test_token",
|
||||
game_id="TEST_GAME",
|
||||
game_role="FRANCE",
|
||||
orders=["A PAR H"],
|
||||
),
|
||||
]
|
||||
|
||||
for original_msg in messages:
|
||||
# Serialize to dict
|
||||
serialized = serialize_message(original_msg)
|
||||
|
||||
# Parse back to object
|
||||
parsed_msg = parse_message(serialized)
|
||||
|
||||
# Should be the same type and have same data
|
||||
assert type(parsed_msg) == type(original_msg)
|
||||
assert parsed_msg.name == original_msg.name
|
||||
assert parsed_msg.request_id == original_msg.request_id
|
||||
|
||||
|
||||
class TestConcurrentOperations:
|
||||
"""Test concurrent WebSocket operations."""
|
||||
|
||||
@pytest.mark.websocket
|
||||
@pytest.mark.slow
|
||||
async def test_multiple_concurrent_clients(self, fake_server):
|
||||
"""Test multiple clients connecting simultaneously."""
|
||||
clients = []
|
||||
|
||||
try:
|
||||
# Create multiple clients
|
||||
for i in range(3):
|
||||
client = TypedWebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
|
||||
await client.connect()
|
||||
await client.authenticate("test_user", "test_password")
|
||||
clients.append(client)
|
||||
|
||||
# All should be connected
|
||||
for client in clients:
|
||||
assert client.token is not None
|
||||
|
||||
# Each can create games independently
|
||||
for i, client in enumerate(clients):
|
||||
game_data = await client.create_game(
|
||||
power_name="FRANCE" if i == 0 else None, n_controls=1
|
||||
)
|
||||
assert game_data["game_id"] is not None
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
for client in clients:
|
||||
try:
|
||||
await client.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
@pytest.mark.websocket
|
||||
async def test_rapid_message_sending(self, client_with_game):
|
||||
"""Test sending multiple messages rapidly."""
|
||||
# Send multiple order updates rapidly
|
||||
for i in range(5):
|
||||
orders = [f"A PAR H # Iteration {i}"]
|
||||
await client_with_game.set_orders("FRANCE", orders)
|
||||
|
||||
# All should succeed without errors
|
||||
Loading…
Add table
Add a link
Reference in a new issue