WIP: Continuing to add pieces for the LMs to play over websocket.

Signed-off-by: Tyler Marques <me@tylermarques.com>
This commit is contained in:
Tyler Marques 2025-06-22 09:26:55 -07:00
parent 939da088a5
commit cc7fb1ccd3
No known key found for this signature in database
GPG key ID: CB99EDCF41D3016F
11 changed files with 1101 additions and 107 deletions

View file

@ -0,0 +1 @@
# Test package initialization

View file

@ -0,0 +1,123 @@
"""
Pytest configuration and fixtures for WebSocket testing.
This module provides shared fixtures and utilities for testing our
WebSocket client implementations against a fake server.
"""
# Add parent directory to path for ai_diplomacy imports (runtime only)
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from typing import AsyncGenerator, Dict, Generator
import pytest
import pytest_asyncio
# Add the bot_client directory to the path so we can import our modules
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from .test_fake_websocket_server import FakeServerManager, FakeWebSocketServer
from typed_websocket_client import TypedWebSocketDiplomacyClient
@pytest_asyncio.fixture(scope="function", autouse=True)
async def fake_server() -> AsyncGenerator[FakeWebSocketServer, None]:
    """
    Run a fake WebSocket server around every test.

    The fixture is function-scoped and ``autouse``, so each test gets a
    freshly started server that is stopped again when the test finishes.
    It listens on localhost:8433, chosen to avoid conflicts with real
    servers on 8432.
    """
    manager = FakeServerManager("localhost", 8433)
    async with manager as running_server:
        yield running_server
@pytest.fixture
def credentials() -> Generator[Dict[str, str], None, None]:
    """
    Yield the username/password pair used to sign in during tests.

    The annotation spells out all three ``Generator`` parameters
    (yield, send, return): before Python 3.13 ``typing.Generator`` has no
    defaults and a two-argument subscription raises ``TypeError`` when the
    module is imported.
    """
    yield {"username": "test_user", "password": "test_password"}
@pytest_asyncio.fixture
async def client() -> AsyncGenerator[TypedWebSocketDiplomacyClient, None]:
    """
    Provide an unconnected TypedWebSocketDiplomacyClient.

    The client is configured for the fake server on localhost:8433 without
    SSL. Tests are responsible for calling ``connect()``; the fixture closes
    the client afterwards on a best-effort basis.
    """
    ws_client = TypedWebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
    yield ws_client

    # Cleanup: close the client connection. Failures here are deliberately
    # ignored (the test may never have connected), but only ordinary errors
    # are swallowed so task cancellation and Ctrl+C still propagate.
    try:
        await ws_client.close()
    except Exception:
        pass
@pytest_asyncio.fixture
async def authenticated_client(fake_server, client):
    """
    Provide a client that is already connected and signed in.

    Requests ``fake_server`` so the server is running before the client
    connects, signs in as ``test_user`` and checks that the returned token
    was stored on the client.

    Declared with ``pytest_asyncio.fixture`` because this is an async
    generator: a plain ``pytest.fixture`` is only driven by pytest-asyncio
    when it runs in auto mode.
    """
    await client.connect()
    token = await client.authenticate("test_user", "test_password")

    assert token is not None
    assert client.token == token

    yield client
@pytest_asyncio.fixture
async def client_with_game(authenticated_client):
    """
    Provide an authenticated client that has created a game as FRANCE.

    The game is created with ``n_controls=1`` to keep the test setup small.
    The fixture checks that the server returned game data and that the
    client recorded a game id before handing the client to the test.

    Declared with ``pytest_asyncio.fixture`` because this is an async
    generator: a plain ``pytest.fixture`` is only driven by pytest-asyncio
    when it runs in auto mode.
    """
    game_data = await authenticated_client.create_game(
        power_name="FRANCE",
        n_controls=1,  # For testing
    )

    assert game_data is not None
    assert authenticated_client.game_id is not None

    yield authenticated_client
# Test utilities
class TestHelpers:
    """Reusable assertion helpers for WebSocket test cases."""

    @staticmethod
    def assert_valid_token(token: str):
        """Check that ``token`` is a string over 10 chars containing ``fake_token_``."""
        assert isinstance(token, str)
        assert len(token) > 10
        assert "fake_token_" in token

    @staticmethod
    def assert_valid_game_data(game_data: dict):
        """Check that ``game_data`` is a dict carrying the expected top-level keys."""
        assert isinstance(game_data, dict)
        for required_key in ("game_id", "phase", "powers", "map_name"):
            assert required_key in game_data

    @staticmethod
    def assert_valid_games_list(games: list):
        """Check that ``games`` is a list whose entries carry the summary keys."""
        assert isinstance(games, list)
        for entry in games:
            for required_key in ("game_id", "status", "phase"):
                assert required_key in entry
@pytest.fixture
def helpers():
    """Return the ``TestHelpers`` class so tests can call its static assertion methods."""
    return TestHelpers

View file

@ -0,0 +1,378 @@
"""
Fake WebSocket Server for Testing
This module provides a mock Diplomacy WebSocket server that implements the protocol
from WEBSOCKET.md for testing purposes. It responds to typed messages with appropriate
responses, allowing us to test our client implementation.
"""
import asyncio
import json
import uuid
import websockets
from typing import Dict, Any, Optional, Set
from loguru import logger
from models import (
# Request types
SignInRequest,
CreateGameRequest,
JoinGameRequest,
ListGamesRequest,
SetOrdersRequest,
ProcessGameRequest,
GetAllPossibleOrdersRequest,
# Response types
DataTokenResponse,
DataGameResponse,
DataGamesResponse,
DataPossibleOrdersResponse,
OkResponse,
ErrorResponse,
# Notifications
parse_message,
serialize_message,
)
class FakeWebSocketServer:
    """
    A fake WebSocket server that implements the Diplomacy protocol for testing.

    The server keeps a small amount of in-memory state (known users, issued
    tokens, created games) and answers each typed request with a typed
    response, following the WEBSOCKET.md protocol specification.
    """

    def __init__(self, host: str = "localhost", port: int = 8433):
        """Store the bind address and initialise the empty mock state."""
        self.host = host
        self.port = port
        self.server = None  # populated by start()
        self.connected_clients: Set[websockets.WebSocketServerProtocol] = set()

        # Mock server state
        self.users = {
            "test_user": "test_password",
            "ai_player": "password",
            "player1": "password"
        }
        self.tokens: Dict[str, str] = {}  # token -> username
        self.games: Dict[str, Dict[str, Any]] = {}
        self.game_counter = 1  # used to build sequential GAME_NNNN ids

    async def start(self):
        """Start listening on ``host``:``port`` and route connections to handle_client."""
        logger.info(f"Starting fake WebSocket server on {self.host}:{self.port}")
        self.server = await websockets.serve(
            self.handle_client,
            self.host,
            self.port
        )
        logger.info("Fake WebSocket server started")

    async def stop(self):
        """Close the server if it was started and wait until it has shut down."""
        if self.server:
            self.server.close()
            await self.server.wait_closed()
            logger.info("Fake WebSocket server stopped")

    async def handle_client(self, websocket):
        """
        Serve one client connection until it closes.

        Every incoming message is passed to handle_message. The connection is
        tracked in ``connected_clients`` for as long as it is being served.
        """
        self.connected_clients.add(websocket)
        logger.info(f"Client connected from {websocket.remote_address}")

        try:
            async for message in websocket:
                await self.handle_message(websocket, message)
        except websockets.exceptions.ConnectionClosed:
            logger.info("Client disconnected")
        except Exception as e:
            logger.error(f"Error handling client: {e}")
        finally:
            self.connected_clients.discard(websocket)

    async def handle_message(self, websocket, message_text: str):
        """
        Parse one raw message, dispatch it, and send the response back.

        If anything fails and the payload was a JSON object, an ErrorResponse
        of type ``PARSING_ERROR`` is sent, reusing the payload's
        ``request_id`` when present. If the payload could not be decoded into
        an object, the failure is only logged.
        """
        data = None
        try:
            # Parse the raw message
            data = json.loads(message_text)
            request = parse_message(data)

            # Generate response based on request type
            response = await self.generate_response(request)

            # Send response back to client
            if response:
                response_data = serialize_message(response)
                await websocket.send(json.dumps(response_data))
                logger.debug(f"Sent response: {response.name}")
        except Exception as e:
            logger.error(f"Error processing message: {e}")
            logger.debug(f"Raw message: {message_text}")

            # Without a decoded JSON object there is no request_id to reply to.
            if not isinstance(data, dict):
                logger.error("Could not send error response")
                return

            try:
                error_response = ErrorResponse(
                    request_id=data.get("request_id", str(uuid.uuid4())),
                    error_type="PARSING_ERROR",
                    message=str(e)
                )
                error_data = serialize_message(error_response)
                await websocket.send(json.dumps(error_data))
            except Exception as send_error:
                # Best effort only (the socket may already be closed), but do
                # not swallow cancellation or interpreter-exit signals.
                logger.error(f"Could not send error response: {send_error}")

    async def generate_response(self, request) -> Optional[Any]:
        """
        Route a parsed request to its handler and return the handler's response.

        Request types are checked in a fixed order; anything unrecognised
        yields an ``UNSUPPORTED_REQUEST`` ErrorResponse.
        """
        logger.info(f"Processing request: {request.name}")

        handlers = (
            # Authentication requests
            (SignInRequest, self.handle_sign_in),
            # Channel-level requests (require token)
            (CreateGameRequest, self.handle_create_game),
            (JoinGameRequest, self.handle_join_game),
            (ListGamesRequest, self.handle_list_games),
            # Game-level requests (require token + game context)
            (SetOrdersRequest, self.handle_set_orders),
            (ProcessGameRequest, self.handle_process_game),
            (GetAllPossibleOrdersRequest, self.handle_get_possible_orders),
        )
        for request_type, handler in handlers:
            if isinstance(request, request_type):
                return await handler(request)

        logger.warning(f"Unhandled request type: {type(request)}")
        return ErrorResponse(
            request_id=request.request_id,
            error_type="UNSUPPORTED_REQUEST",
            message=f"Request type {request.name} not supported by fake server"
        )

    async def handle_sign_in(self, request: SignInRequest) -> Any:
        """
        Handle an authentication request.

        Returns a DataTokenResponse with a new ``fake_token_...`` token when
        the username/password pair is in ``users``; otherwise an
        ``AUTHENTICATION_ERROR`` ErrorResponse.
        """
        username = request.username
        password = request.password

        if username in self.users and self.users[username] == password:
            # Generate auth token
            token = f"fake_token_{uuid.uuid4().hex[:16]}"
            self.tokens[token] = username
            return DataTokenResponse(
                request_id=request.request_id,
                data=token
            )

        return ErrorResponse(
            request_id=request.request_id,
            error_type="AUTHENTICATION_ERROR",
            message="Invalid username or password"
        )

    async def handle_create_game(self, request: CreateGameRequest) -> Any:
        """
        Handle a game creation request.

        Creates a game with a sequential ``GAME_NNNN`` id, the seven standard
        powers and their home centers, and phase ``S1901M``. If the request
        names a power, that power is recorded as controlled by the
        requesting user.
        """
        if not self.validate_token(request.token):
            return self.create_auth_error(request.request_id)

        # Create a new game
        game_id = f"GAME_{self.game_counter:04d}"
        self.game_counter += 1

        game_data = {
            "game_id": game_id,
            "map_name": request.map_name,
            "rules": request.rules,
            "phase": "S1901M",
            "status": "FORMING",
            "n_controls": request.n_controls,
            "powers": {
                "AUSTRIA": {"units": [], "centers": ["VIE", "BUD", "TRI"], "is_eliminated": False},
                "ENGLAND": {"units": [], "centers": ["EDI", "LVP", "LON"], "is_eliminated": False},
                "FRANCE": {"units": [], "centers": ["PAR", "BRE", "MAR"], "is_eliminated": False},
                "GERMANY": {"units": [], "centers": ["BER", "MUN", "KIE"], "is_eliminated": False},
                "ITALY": {"units": [], "centers": ["ROM", "NAP", "VEN"], "is_eliminated": False},
                "RUSSIA": {"units": [], "centers": ["MOS", "SEV", "STP", "WAR"], "is_eliminated": False},
                "TURKEY": {"units": [], "centers": ["ANK", "CON", "SMY"], "is_eliminated": False},
            },
            "controlled_powers": {}
        }

        # If a specific power was requested, assign it
        if request.power_name:
            game_data["controlled_powers"][request.power_name] = self.tokens[request.token]

        self.games[game_id] = game_data

        return DataGameResponse(
            request_id=request.request_id,
            data=game_data
        )

    async def handle_join_game(self, request: JoinGameRequest) -> Any:
        """
        Handle a game join request.

        Returns the stored game data, after recording the requested power
        (if any) as controlled by the requesting user. Unknown game ids
        yield a ``GAME_NOT_FOUND`` ErrorResponse.
        """
        if not self.validate_token(request.token):
            return self.create_auth_error(request.request_id)

        game_id = request.game_id
        if game_id not in self.games:
            return self._game_not_found_error(request.request_id, game_id)

        # Update the stored game directly: a shallow copy would still share
        # the nested "controlled_powers" dict, so it would isolate nothing.
        game_data = self.games[game_id]

        # If a specific power was requested, assign it
        if request.power_name:
            game_data["controlled_powers"][request.power_name] = self.tokens[request.token]

        return DataGameResponse(
            request_id=request.request_id,
            data=game_data
        )

    async def handle_list_games(self, request: ListGamesRequest) -> Any:
        """Handle a list-games request by returning a summary of every stored game."""
        if not self.validate_token(request.token):
            return self.create_auth_error(request.request_id)

        # Return simplified game info
        games_list = [
            {
                "game_id": game_id,
                "map_name": game_data["map_name"],
                "status": game_data["status"],
                "phase": game_data["phase"],
                "n_controls": game_data["n_controls"]
            }
            for game_id, game_data in self.games.items()
        ]

        return DataGamesResponse(
            request_id=request.request_id,
            data=games_list
        )

    async def handle_set_orders(self, request: SetOrdersRequest) -> Any:
        """
        Handle a set-orders request.

        Orders are neither validated nor stored; they are logged and
        acknowledged with an OkResponse. Unknown game ids yield a
        ``GAME_NOT_FOUND`` ErrorResponse.
        """
        if not self.validate_token(request.token):
            return self.create_auth_error(request.request_id)

        game_id = request.game_id
        if game_id not in self.games:
            return self._game_not_found_error(request.request_id, game_id)

        # In a real server, we'd validate the orders and store them
        # For testing, we just acknowledge receipt
        logger.info(f"Orders received for {request.game_role}: {request.orders}")

        return OkResponse(request_id=request.request_id)

    async def handle_process_game(self, request: ProcessGameRequest) -> Any:
        """
        Handle a process-game request by advancing the stored phase.

        The progression is a fixed placeholder: S1901M -> F1901M -> W1901A,
        and any other phase becomes S1902M. No notification is sent to
        clients; the request is acknowledged with an OkResponse.
        """
        if not self.validate_token(request.token):
            return self.create_auth_error(request.request_id)

        game_id = request.game_id
        if game_id not in self.games:
            return self._game_not_found_error(request.request_id, game_id)

        # Simulate game processing with a simple phase progression. Phases
        # outside the table all map to S1902M (so S1902M maps to itself).
        next_phase = {"S1901M": "F1901M", "F1901M": "W1901A"}
        game_data = self.games[game_id]
        game_data["phase"] = next_phase.get(game_data["phase"], "S1902M")

        # Send notification to all clients (in real implementation)
        # For testing, we'll just return OK
        return OkResponse(request_id=request.request_id)

    async def handle_get_possible_orders(self, request: GetAllPossibleOrdersRequest) -> Any:
        """Handle a possible-orders request by returning a fixed mock order set for PAR, BRE and MAR."""
        if not self.validate_token(request.token):
            return self.create_auth_error(request.request_id)

        # Return mock possible orders
        possible_orders = {
            "PAR": ["A PAR H", "A PAR - BUR", "A PAR - PIC", "A PAR - GAS"],
            "BRE": ["F BRE H", "F BRE - MAO", "F BRE - ENG", "F BRE - PIC"],
            "MAR": ["A MAR H", "A MAR - GAS", "A MAR - SPA", "A MAR - PIE"],
        }

        return DataPossibleOrdersResponse(
            request_id=request.request_id,
            data=possible_orders
        )

    def validate_token(self, token: str) -> bool:
        """Return True when ``token`` was issued by handle_sign_in."""
        return token in self.tokens

    def create_auth_error(self, request_id: str) -> ErrorResponse:
        """Create a standard authentication error response."""
        return ErrorResponse(
            request_id=request_id,
            error_type="AUTHENTICATION_ERROR",
            message="Invalid or missing authentication token"
        )

    def _game_not_found_error(self, request_id: str, game_id: str) -> ErrorResponse:
        """Create the standard error response for an unknown game id."""
        return ErrorResponse(
            request_id=request_id,
            error_type="GAME_NOT_FOUND",
            message=f"Game {game_id} not found"
        )
class FakeServerManager:
    """Async context manager that starts a FakeWebSocketServer on entry and stops it on exit."""

    def __init__(self, host: str = "localhost", port: int = 8433):
        """Create (but do not start) the fake server for ``host``:``port``."""
        self.server = FakeWebSocketServer(host, port)

    async def __aenter__(self):
        """Start the server, pause briefly, and return the server instance."""
        fake = self.server
        await fake.start()
        # Give the server a moment to start
        await asyncio.sleep(0.1)
        return fake

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Stop the server; exceptions raised inside the ``async with`` body are not suppressed."""
        await self.server.stop()
# Utility function for tests
async def run_fake_server(host: str = "localhost", port: int = 8433):
    """
    Run the fake server until interrupted (for manual testing).

    The server is started through FakeServerManager, so it is stopped again
    when this coroutine exits for any reason.
    """
    async with FakeServerManager(host, port):
        logger.info(f"Fake server running on {host}:{port}")
        logger.info("Press Ctrl+C to stop")

        try:
            # Keep running until interrupted
            await asyncio.Event().wait()
        except (KeyboardInterrupt, asyncio.CancelledError):
            # Under asyncio.run on Python 3.11+, Ctrl+C cancels the main task,
            # so the interrupt arrives here as CancelledError rather than
            # KeyboardInterrupt. Handle both so shutdown is logged and clean.
            logger.info("Server stopped by user")
if __name__ == "__main__":
    # Manual-testing entry point: run the fake server with its default
    # host and port until the process is interrupted.
    asyncio.run(run_fake_server())

View file

@ -0,0 +1,324 @@
"""
Tests for WebSocket client functionality using the fake server.
These tests verify that our typed WebSocket client can properly:
1. Connect to a WebSocket server
2. Send and receive typed messages according to the protocol
3. Handle authentication, game creation, and game operations
4. Validate message format and content
"""
from typing import Dict
import pytest
import uuid
from models import (
SignInRequest,
CreateGameRequest,
SetOrdersRequest,
serialize_message,
parse_message,
)
from typed_websocket_client import TypedWebSocketDiplomacyClient
class TestAuthentication:
    """Sign-in behaviour: success, failure, and the sign-in wire format."""

    @pytest.mark.websocket
    async def test_successful_authentication(
        self, client: TypedWebSocketDiplomacyClient, helpers
    ):
        """Valid credentials leave a well-formed token on the client."""
        await client.connect_and_authenticate("test_user", "test_password")

        issued_token = client.token
        helpers.assert_valid_token(issued_token)

    @pytest.mark.websocket
    async def test_failed_authentication(self, client):
        """Invalid credentials raise ValueError and leave the client without a token."""
        await client.connect()

        with pytest.raises(ValueError, match="Authentication failed"):
            await client.authenticate("invalid_user", "wrong_password")

        assert client.token is None

    @pytest.mark.websocket
    async def test_authentication_message_format(
        self, credentials: Dict[str, str], client: TypedWebSocketDiplomacyClient
    ):
        """A SignInRequest serializes to the expected fields and parses back."""
        await client.connect_and_authenticate(**credentials)

        # Build a sign-in request by hand so the message format can be inspected
        request_id = str(uuid.uuid4())
        sign_in = SignInRequest(
            request_id=request_id, username="test_user", password="test_password"
        )

        # The serialized form carries the request name and both credentials
        message_data = serialize_message(sign_in)
        expected_fields = {
            "name": "sign_in",
            "username": "test_user",
            "password": "test_password",
        }
        for field, value in expected_fields.items():
            assert message_data[field] == value
        assert "request_id" in message_data

        # Parsing the serialized form yields a SignInRequest again
        round_tripped = parse_message(message_data)
        assert isinstance(round_tripped, SignInRequest)
        assert round_tripped.username == "test_user"
class TestGameOperations:
    """Game lifecycle against the fake server: create, list, and join."""

    @pytest.mark.websocket
    async def test_create_game(self, authenticated_client, helpers):
        """Creating a game returns valid game data and sets the client's id and role."""
        created = await authenticated_client.create_game(
            power_name="FRANCE", n_controls=1
        )

        helpers.assert_valid_game_data(created)
        assert authenticated_client.game_id is not None
        assert authenticated_client.game_role == "FRANCE"

    @pytest.mark.websocket
    async def test_list_games(self, authenticated_client, helpers):
        """A freshly created game shows up in the games list."""
        # Create a game first so the list is not empty
        await authenticated_client.create_game(power_name="ENGLAND", n_controls=1)

        listed = await authenticated_client.list_games()

        helpers.assert_valid_games_list(listed)
        assert len(listed) >= 1

        # The game we just created must be among the listed ids
        listed_ids = [entry["game_id"] for entry in listed]
        assert authenticated_client.game_id in listed_ids

    @pytest.mark.websocket
    async def test_join_existing_game(
        self, client: TypedWebSocketDiplomacyClient, helpers
    ):
        """A second client can join a game created by the first, as another power."""
        # First client: connect, sign in, and create the game
        await client.connect()
        await client.authenticate("test_user", "test_password")
        created = await client.create_game(power_name="FRANCE", n_controls=1)
        game_id = created["game_id"]

        # Second client: same class, separate connection and user
        second_client = client.__class__("localhost", 8433, use_ssl=False)
        try:
            await second_client.connect()
            await second_client.authenticate("ai_player", "password")

            # Join the existing game as a different power
            joined = await second_client.join_game(game_id, power_name="ENGLAND")

            helpers.assert_valid_game_data(joined)
            assert joined["game_id"] == game_id
            assert second_client.game_role == "ENGLAND"
        finally:
            await second_client.close()

    @pytest.mark.websocket
    async def test_join_nonexistent_game(self, authenticated_client):
        """Joining an unknown game id raises ValueError."""
        with pytest.raises(ValueError, match="Game .* not found"):
            await authenticated_client.join_game(
                "NONEXISTENT_GAME", power_name="FRANCE"
            )
class TestGamePlay:
    """Gameplay operations: submitting orders, processing, and querying orders."""

    @pytest.mark.websocket
    async def test_set_orders(self, client_with_game):
        """Submitting a list of hold orders for FRANCE completes without raising."""
        hold_orders = ["A PAR H", "F BRE H", "A MAR H"]

        # Passing means no exception was raised
        await client_with_game.set_orders("FRANCE", hold_orders)

    @pytest.mark.websocket
    async def test_set_empty_orders(self, client_with_game):
        """Submitting an empty order list completes without raising."""
        await client_with_game.set_orders("FRANCE", [])

    @pytest.mark.websocket
    async def test_process_game(self, client_with_game):
        """Processing the game phase completes without raising."""
        # The fake server allows processing, so no exception is expected
        await client_with_game.process_game()

    @pytest.mark.websocket
    async def test_get_possible_orders(self, client_with_game):
        """Possible orders come back as a non-empty dict of location -> list of strings."""
        orders_by_location = await client_with_game.get_all_possible_orders()

        assert isinstance(orders_by_location, dict)
        # The fake server always returns some mock orders
        assert len(orders_by_location) > 0

        # Shape check: every key is a string, every value a list of strings
        for location, location_orders in orders_by_location.items():
            assert isinstance(location, str)
            assert isinstance(location_orders, list)
            assert all(isinstance(order, str) for order in location_orders)
class TestMessageValidation:
    """Test message format validation and error handling."""

    @pytest.mark.websocket
    async def test_create_game_message_format(self, authenticated_client):
        """A CreateGameRequest serializes with every field it was built from."""
        request = CreateGameRequest(
            request_id=str(uuid.uuid4()),
            token=authenticated_client.token,
            map_name="standard",
            rules=["NO_PRESS", "IGNORE_ERRORS"],
            power_name="AUSTRIA",
            n_controls=7,
        )

        message_data = serialize_message(request)

        # Verify all required fields are present
        assert message_data["name"] == "create_game"
        assert message_data["token"] == authenticated_client.token
        assert message_data["map_name"] == "standard"
        assert message_data["rules"] == ["NO_PRESS", "IGNORE_ERRORS"]
        assert message_data["power_name"] == "AUSTRIA"
        assert message_data["n_controls"] == 7
        assert "request_id" in message_data

    @pytest.mark.websocket
    async def test_set_orders_message_format(self, client_with_game):
        """A SetOrdersRequest serializes with every field it was built from."""
        request = SetOrdersRequest(
            request_id=str(uuid.uuid4()),
            token=client_with_game.token,
            game_id=client_with_game.game_id,
            game_role="FRANCE",
            orders=["A PAR H", "F BRE - ENG"],
        )

        message_data = serialize_message(request)

        # Verify all required fields are present
        assert message_data["name"] == "set_orders"
        assert message_data["token"] == client_with_game.token
        assert message_data["game_id"] == client_with_game.game_id
        assert message_data["game_role"] == "FRANCE"
        assert message_data["orders"] == ["A PAR H", "F BRE - ENG"]
        assert "request_id" in message_data

    @pytest.mark.websocket
    async def test_error_response_handling(self, client):
        """A server error response surfaces to the caller as a ValueError."""
        await client.connect()

        # Try to authenticate with wrong credentials
        with pytest.raises(ValueError) as exc_info:
            await client.authenticate("bad_user", "bad_pass")

        # Checked after the ``with`` block: exc_info is only populated once
        # the block has exited.
        assert "Authentication failed" in str(exc_info.value)

    @pytest.mark.websocket
    async def test_unauthenticated_requests_fail(self, client):
        """Game requests made before signing in raise ValueError."""
        await client.connect()
        # Don't authenticate

        with pytest.raises(ValueError, match="Must authenticate first"):
            await client.create_game()

    @pytest.mark.websocket
    async def test_message_round_trip(self, authenticated_client):
        """Serializing then parsing a request preserves its type, name and request id."""
        # Test various message types
        messages = [
            SignInRequest(
                request_id=str(uuid.uuid4()), username="test", password="pass"
            ),
            CreateGameRequest(
                request_id=str(uuid.uuid4()),
                token="test_token",
                map_name="standard",
                power_name="FRANCE",
            ),
            SetOrdersRequest(
                request_id=str(uuid.uuid4()),
                token="test_token",
                game_id="TEST_GAME",
                game_role="FRANCE",
                orders=["A PAR H"],
            ),
        ]

        for original_msg in messages:
            # Serialize to dict
            serialized = serialize_message(original_msg)

            # Parse back to object
            parsed_msg = parse_message(serialized)

            # Exact-type check: compare the classes by identity rather than
            # with ``==`` so the result cannot depend on a custom __eq__.
            assert type(parsed_msg) is type(original_msg)
            assert parsed_msg.name == original_msg.name
            assert parsed_msg.request_id == original_msg.request_id
class TestConcurrentOperations:
    """Test concurrent WebSocket operations."""

    @pytest.mark.websocket
    @pytest.mark.slow
    async def test_multiple_concurrent_clients(self, fake_server):
        """
        Several clients can hold connections at once and each create a game.

        The clients are opened one after another and kept open together;
        only the first asks for a specific power.
        """
        clients = []

        try:
            # Create multiple clients
            for _ in range(3):
                client = TypedWebSocketDiplomacyClient("localhost", 8433, use_ssl=False)
                # Register before connecting so a failed connect/authenticate
                # still gets the client closed in the cleanup below.
                clients.append(client)
                await client.connect()
                await client.authenticate("test_user", "test_password")

            # All should be connected
            for client in clients:
                assert client.token is not None

            # Each can create games independently
            for i, client in enumerate(clients):
                game_data = await client.create_game(
                    power_name="FRANCE" if i == 0 else None, n_controls=1
                )
                assert game_data["game_id"] is not None

        finally:
            # Cleanup is best effort, but only ordinary errors are ignored so
            # task cancellation and Ctrl+C still propagate.
            for client in clients:
                try:
                    await client.close()
                except Exception:
                    pass

    @pytest.mark.websocket
    async def test_rapid_message_sending(self, client_with_game):
        """Five back-to-back order submissions all complete without raising."""
        # Send multiple order updates rapidly
        for i in range(5):
            orders = [f"A PAR H  # Iteration {i}"]
            await client_with_game.set_orders("FRANCE", orders)

        # All should succeed without errors