mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-19 12:58:09 +00:00
INIT
This commit is contained in:
parent
e8530a146d
commit
93c073e2df
295 changed files with 86794 additions and 0 deletions
479
diplomacy/daide/tests/test_daide_game.py
Normal file
479
diplomacy/daide/tests/test_daide_game.py
Normal file
|
|
@ -0,0 +1,479 @@
|
|||
# ==============================================================================
|
||||
# Copyright (C) 2019 - Philip Paquette
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify it under
|
||||
# the terms of the GNU Affero General Public License as published by the Free
|
||||
# Software Foundation, either version 3 of the License, or (at your option) any
|
||||
# later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but WITHOUT
|
||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
|
||||
# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
|
||||
# details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License along
|
||||
# with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
# ==============================================================================
|
||||
""" Tests for complete DAIDE games """
|
||||
from collections import namedtuple
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import signal
|
||||
|
||||
from tornado import gen
|
||||
from tornado.concurrent import chain_future, Future
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.iostream import StreamClosedError
|
||||
from tornado.tcpclient import TCPClient
|
||||
|
||||
from diplomacy import Server
|
||||
from diplomacy.daide import messages, tokens
|
||||
from diplomacy.daide.tokens import Token
|
||||
from diplomacy.daide.utils import str_to_bytes, bytes_to_str
|
||||
from diplomacy.server.server import is_port_opened
|
||||
from diplomacy.server.server_game import ServerGame
|
||||
from diplomacy.client.connection import connect
|
||||
from diplomacy.utils import common, constants, strings
|
||||
|
||||
# Constants
|
||||
LOGGER = logging.getLogger('diplomacy.daide.tests.test_daide_game')
|
||||
HOSTNAME = 'localhost'
|
||||
FILE_FOLDER_NAME = os.path.abspath(os.path.dirname(__file__))
|
||||
BOT_KEYWORD = '__bot__'
|
||||
|
||||
# Named Tuples
|
||||
DaideComm = namedtuple('DaideComm', ['client_id', 'request', 'resp_notifs'])
|
||||
ClientRequest = namedtuple('ClientRequest', ['client', 'request'])
|
||||
|
||||
# Adapted from: https://stackoverflow.com/questions/492519/timeout-on-a-function-call
|
||||
def run_with_timeout(callable_fn, timeout):
|
||||
""" Raises an error on timeout """
|
||||
def handler(signum, frame):
|
||||
""" Raises a timeout """
|
||||
raise TimeoutError()
|
||||
|
||||
signal.signal(signal.SIGALRM, handler)
|
||||
signal.alarm(timeout)
|
||||
try:
|
||||
return callable_fn()
|
||||
except TimeoutError as exc:
|
||||
raise exc
|
||||
finally:
|
||||
signal.alarm(0)
|
||||
|
||||
class ClientCommsSimulator:
|
||||
""" Represents a client's comms """
|
||||
def __init__(self, client_id):
|
||||
""" Constructor
|
||||
|
||||
:param client_id: the id
|
||||
"""
|
||||
self._id = client_id
|
||||
self._stream = None
|
||||
self._is_game_joined = False
|
||||
self._comms = False
|
||||
|
||||
@property
|
||||
def stream(self):
|
||||
""" Returns the stream """
|
||||
return self._stream
|
||||
|
||||
@property
|
||||
def comms(self):
|
||||
""" Returns the comms """
|
||||
return self._comms
|
||||
|
||||
@property
|
||||
def is_game_joined(self):
|
||||
""" Returns if the client has joinded the game """
|
||||
return self._is_game_joined
|
||||
|
||||
def set_comms(self, comms):
|
||||
""" Set the client's communications.
|
||||
|
||||
The client's comms will be sorted to have the requests of a phase
|
||||
preceding the responses / notifications of the phase
|
||||
|
||||
:param comms: the game's communications
|
||||
"""
|
||||
self._comms = [comm for comm in comms if comm.client_id == self._id]
|
||||
|
||||
comm_idx = 0
|
||||
while comm_idx < len(self._comms):
|
||||
comm = self._comms[comm_idx]
|
||||
|
||||
# Find the request being right after a synchonization point (TME notification)
|
||||
if not comm.request:
|
||||
comm_idx += 1
|
||||
continue
|
||||
|
||||
# Next communication to sort
|
||||
next_comm_idx = comm_idx + 1
|
||||
while next_comm_idx < len(self._comms):
|
||||
next_comm = self._comms[next_comm_idx]
|
||||
|
||||
# Group the request at the beginning of the communications in the phase
|
||||
if next_comm.request:
|
||||
comm_idx += 1
|
||||
self._comms.insert(comm_idx, self._comms.pop(next_comm_idx))
|
||||
|
||||
# Synchonization point is a TME notif as it marks the beginning of a phase
|
||||
if any(resp_notif.startswith('TME') for resp_notif in next_comm.resp_notifs):
|
||||
break
|
||||
|
||||
next_comm_idx += 1
|
||||
|
||||
comm_idx += 1
|
||||
|
||||
def pop_next_request(self, comms):
|
||||
""" Pop the next request from a DAIDE communications list
|
||||
|
||||
:return: The next request along with the updated list of communications
|
||||
or None and the updated list of communications
|
||||
"""
|
||||
com = next(iter(comms), None)
|
||||
request = None
|
||||
|
||||
while com and com.client_id == self._id:
|
||||
if com.request:
|
||||
request = com.request
|
||||
comms[0] = DaideComm(com.client_id, '', com.resp_notifs)
|
||||
LOGGER.info('[%d:%d] preparing to send request [%s]', self._id, self.stream.socket.fileno()+1, request)
|
||||
break
|
||||
elif com.resp_notifs:
|
||||
break
|
||||
else:
|
||||
comms.pop(0)
|
||||
com = next(iter(comms), None)
|
||||
|
||||
return request, comms
|
||||
|
||||
def pop_next_resp_notif(self, comms):
|
||||
""" Pop the next response or notifcation from a DAIDE communications list
|
||||
|
||||
:return: The next response or notifcation along with the updated list of communications
|
||||
or None and the updated list of communications
|
||||
"""
|
||||
com = next(iter(comms), None)
|
||||
resp_notif = None
|
||||
|
||||
while com and com.client_id == self._id:
|
||||
if com.request:
|
||||
break
|
||||
elif com.resp_notifs:
|
||||
resp_notif = com.resp_notifs.pop(0)
|
||||
LOGGER.info('[%d:%d] waiting for resp_notif [%s]', self._id, self.stream.socket.fileno()+1, resp_notif)
|
||||
break
|
||||
else:
|
||||
comms.pop(0)
|
||||
com = next(iter(comms), None)
|
||||
|
||||
return resp_notif, comms
|
||||
|
||||
@gen.coroutine
|
||||
def connect(self, game_port):
|
||||
""" Connect to the DAIDE server
|
||||
|
||||
:param game_port: the DAIDE game's port
|
||||
"""
|
||||
self._stream = yield TCPClient().connect('localhost', game_port)
|
||||
LOGGER.info('Connected to %d', game_port)
|
||||
message = messages.InitialMessage()
|
||||
yield self._stream.write(bytes(message))
|
||||
yield messages.DaideMessage.from_stream(self._stream)
|
||||
|
||||
@gen.coroutine
|
||||
def send_request(self, request):
|
||||
""" Sends a request
|
||||
|
||||
:param request: the request to send
|
||||
"""
|
||||
message = messages.DiplomacyMessage()
|
||||
message.content = str_to_bytes(request)
|
||||
yield self._stream.write(bytes(message))
|
||||
|
||||
@gen.coroutine
|
||||
def validate_resp_notifs(self, expected_resp_notifs):
|
||||
""" Validate that expected response / notifications are received regardless of the order
|
||||
|
||||
:param expected_resp_notifs: the response / notifications to receive
|
||||
"""
|
||||
while expected_resp_notifs:
|
||||
resp_notif_message = yield messages.DaideMessage.from_stream(self._stream)
|
||||
|
||||
resp_notif = bytes_to_str(resp_notif_message.content)
|
||||
if Token(from_bytes=resp_notif_message.content[:2]) == tokens.HLO:
|
||||
resp_notif = resp_notif.split(' ')
|
||||
resp_notif[5] = expected_resp_notifs[0].split(' ')[5]
|
||||
resp_notif = ' '.join(resp_notif)
|
||||
self._is_game_joined = True
|
||||
|
||||
LOGGER.info('[%d:%d] Received reply [%s]', self._id, self.stream.socket.fileno() + 1, str(resp_notif))
|
||||
LOGGER.info('[%d:%d] Replies in buffer [%s]', self._id, self.stream.socket.fileno() + 1,
|
||||
','.join(expected_resp_notifs))
|
||||
assert resp_notif in expected_resp_notifs
|
||||
expected_resp_notifs.remove(resp_notif)
|
||||
|
||||
@gen.coroutine
|
||||
def execute_phase(self, game_id, channels):
|
||||
""" Execute a single communications phase
|
||||
|
||||
:param game_id: The game id of the current game
|
||||
:param channels: A dictionary of power name to its channel (BOT_KEYWORD for dummies)
|
||||
:return: True if there are communications left to execute in the game
|
||||
"""
|
||||
# pylint: disable=too-many-nested-blocks
|
||||
try:
|
||||
while self._comms:
|
||||
request, self._comms = self.pop_next_request(self._comms)
|
||||
|
||||
# If request is GOF - Sending empty orders for all human and dummy powers
|
||||
if request and request.split()[0] == 'GOF':
|
||||
|
||||
# Joining all games first
|
||||
games = {}
|
||||
for power_name, channel in channels.items():
|
||||
if power_name == BOT_KEYWORD:
|
||||
all_dummy_power_names = yield channel.get_dummy_waiting_powers(buffer_size=100)
|
||||
for dummy_name in all_dummy_power_names.get(game_id, []):
|
||||
games[dummy_name] = yield channel.join_game(game_id=game_id, power_name=dummy_name)
|
||||
else:
|
||||
games[power_name] = yield channel.join_game(game_id=game_id, power_name=power_name)
|
||||
|
||||
# Submitting orders
|
||||
for power_name, game in games.items():
|
||||
yield game.set_orders(power_name=power_name, orders=[], wait=False)
|
||||
|
||||
# Sending request
|
||||
if request is not None:
|
||||
yield self.send_request(request)
|
||||
|
||||
expected_resp_notifs = []
|
||||
expected_resp_notif, self._comms = self.pop_next_resp_notif(self._comms)
|
||||
|
||||
while expected_resp_notif is not None:
|
||||
expected_resp_notifs.append(expected_resp_notif)
|
||||
# Synchonization point is the request being right after a TME notif or
|
||||
# the next set of responses / notifications
|
||||
if expected_resp_notif.startswith('TME'):
|
||||
break
|
||||
expected_resp_notif, self._comms = self.pop_next_resp_notif(self._comms)
|
||||
|
||||
if expected_resp_notifs:
|
||||
future = self.validate_resp_notifs(expected_resp_notifs)
|
||||
@gen.coroutine
|
||||
def validate_resp_notifs():
|
||||
yield future
|
||||
run_with_timeout(validate_resp_notifs, 1)
|
||||
yield future
|
||||
break
|
||||
|
||||
except StreamClosedError as err:
|
||||
LOGGER.error('Stream closed: %s', err)
|
||||
return False
|
||||
|
||||
return bool(self._comms)
|
||||
|
||||
class ClientsCommsSimulator:
|
||||
""" Represents multi clients's communications """
|
||||
def __init__(self, nb_clients, csv_file, game_id, channels):
|
||||
""" Constructor
|
||||
|
||||
:param nb_clients: the number of clients
|
||||
:param csv_file: the csv containing the communications in chronological order
|
||||
:param game_id: The game id on the server
|
||||
:param channels: A dictionary of power name to its channel (BOT_KEYWORD for dummies)
|
||||
"""
|
||||
with open(csv_file, 'r') as file:
|
||||
content = file.read()
|
||||
|
||||
content = [line.split(',') for line in content.split('\n') if not line.startswith('#')]
|
||||
|
||||
self._game_port = None
|
||||
self._nb_clients = nb_clients
|
||||
self._comms = [DaideComm(int(line[0]), line[1], line[2:]) for line in content if line[0]]
|
||||
self._clients = {}
|
||||
self._game_id = game_id
|
||||
self._channels = channels
|
||||
|
||||
@gen.coroutine
|
||||
def retrieve_game_port(self, host, port):
|
||||
""" Retreive and store the game's port
|
||||
|
||||
:param host: the host
|
||||
:param port: the port
|
||||
:param game_id: the game id
|
||||
"""
|
||||
connection = yield connect(host, port)
|
||||
self._game_port = yield connection.get_daide_port(self._game_id)
|
||||
yield connection.connection.close()
|
||||
|
||||
@gen.coroutine
|
||||
def execute(self):
|
||||
""" Executes the communications between clients """
|
||||
try:
|
||||
# Synchronize clients joining the game
|
||||
while self._comms and (not self._clients
|
||||
or not all(client.is_game_joined for client in self._clients.values())):
|
||||
try:
|
||||
next_comm = next(iter(self._comms)) # type: DaideComm
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
if next_comm.client_id not in self._clients and len(self._clients) < self._nb_clients:
|
||||
client = ClientCommsSimulator(next_comm.client_id)
|
||||
yield client.connect(self._game_port)
|
||||
self._clients[next_comm.client_id] = client
|
||||
|
||||
for client in self._clients.values():
|
||||
request, self._comms = client.pop_next_request(self._comms)
|
||||
|
||||
if request is not None:
|
||||
yield client.send_request(request)
|
||||
|
||||
expected_resp_notif, self._comms = client.pop_next_resp_notif(self._comms)
|
||||
|
||||
while expected_resp_notif is not None:
|
||||
yield client.validate_resp_notifs([expected_resp_notif])
|
||||
expected_resp_notif, self._comms = client.pop_next_resp_notif(self._comms)
|
||||
|
||||
except StreamClosedError as err:
|
||||
LOGGER.error('Stream closed: %s', err)
|
||||
|
||||
execution_running = []
|
||||
|
||||
for client in self._clients.values():
|
||||
client.set_comms(self._comms)
|
||||
execution_running.append(client.execute_phase(self._game_id, self._channels))
|
||||
|
||||
execution_running = yield execution_running
|
||||
|
||||
while any(execution_running):
|
||||
execution_running = yield [client.execute_phase(self._game_id, self._channels)
|
||||
for client in self._clients.values()]
|
||||
|
||||
assert all(not client.comms for client in self._clients.values())
|
||||
|
||||
def run_game_data(nb_daide_clients, rules, csv_file):
|
||||
""" Start a server and a client to test DAIDE communications
|
||||
|
||||
:param port: The port of the DAIDE server
|
||||
:param csv_file: the csv file containing the list of DAIDE communications
|
||||
"""
|
||||
server = Server()
|
||||
io_loop = IOLoop()
|
||||
io_loop.make_current()
|
||||
common.Tornado.stop_loop_on_callback_error(io_loop)
|
||||
|
||||
@gen.coroutine
|
||||
def coroutine_func():
|
||||
""" Concrete call to main function. """
|
||||
port = random.randint(9000, 9999)
|
||||
|
||||
while is_port_opened(port, HOSTNAME):
|
||||
port = random.randint(9000, 9999)
|
||||
|
||||
nb_human_players = 1 if nb_daide_clients < 7 else 0
|
||||
|
||||
server.start(port=port)
|
||||
server_game = ServerGame(map_name='standard',
|
||||
n_controls=nb_daide_clients + nb_human_players,
|
||||
rules=rules,
|
||||
server=server)
|
||||
|
||||
# Register game on server.
|
||||
game_id = server_game.game_id
|
||||
server.add_new_game(server_game)
|
||||
server.start_new_daide_server(game_id)
|
||||
|
||||
# Creating human player
|
||||
human_username = 'username'
|
||||
human_password = 'password'
|
||||
|
||||
# Creating bot player to play for dummy powers
|
||||
bot_username = constants.PRIVATE_BOT_USERNAME
|
||||
bot_password = constants.PRIVATE_BOT_PASSWORD
|
||||
|
||||
# Connecting
|
||||
connection = yield connect(HOSTNAME, port)
|
||||
human_channel = yield connection.authenticate(human_username, human_password)
|
||||
bot_channel = yield connection.authenticate(bot_username, bot_password)
|
||||
|
||||
# Joining human to game
|
||||
channels = {BOT_KEYWORD: bot_channel}
|
||||
if nb_human_players:
|
||||
yield human_channel.join_game(game_id=game_id, power_name='AUSTRIA')
|
||||
channels['AUSTRIA'] = human_channel
|
||||
|
||||
comms_simulator = ClientsCommsSimulator(nb_daide_clients, csv_file, game_id, channels)
|
||||
yield comms_simulator.retrieve_game_port(HOSTNAME, port)
|
||||
|
||||
# done_future is only used to prevent pylint E1101 errors on daide_future
|
||||
done_future = Future()
|
||||
daide_future = comms_simulator.execute()
|
||||
chain_future(daide_future, done_future)
|
||||
|
||||
for _ in range(3 + nb_daide_clients):
|
||||
if done_future.done() or server_game.count_controlled_powers() >= (nb_daide_clients + nb_human_players):
|
||||
break
|
||||
yield gen.sleep(2.5)
|
||||
else:
|
||||
raise TimeoutError()
|
||||
|
||||
# Waiting for process to finish
|
||||
while not done_future.done() and server_game.status == strings.ACTIVE:
|
||||
yield gen.sleep(2.5)
|
||||
|
||||
yield daide_future
|
||||
|
||||
try:
|
||||
io_loop.run_sync(coroutine_func)
|
||||
|
||||
finally:
|
||||
server.stop_daide_server(None)
|
||||
if server.backend.http_server:
|
||||
server.backend.http_server.stop()
|
||||
|
||||
io_loop.stop()
|
||||
io_loop.clear_current()
|
||||
io_loop.close()
|
||||
|
||||
server = None
|
||||
Server.__cache__.clear()
|
||||
|
||||
def test_game_reject_map():
|
||||
""" Test a game where the client rejects the map """
|
||||
_ = Server() # Initialize cache to prevent timeouts during tests
|
||||
game_path = os.path.join(FILE_FOLDER_NAME, 'game_data_1_reject_map.csv')
|
||||
run_with_timeout(lambda: run_game_data(1, ['NO_PRESS', 'IGNORE_ERRORS', 'POWER_CHOICE'], game_path), 60)
|
||||
|
||||
def test_game_1():
|
||||
""" Test a complete 1 player game """
|
||||
_ = Server() # Initialize cache to prevent timeouts during tests
|
||||
game_path = os.path.join(FILE_FOLDER_NAME, 'game_data_1.csv')
|
||||
run_with_timeout(lambda: run_game_data(1, ['NO_PRESS', 'IGNORE_ERRORS', 'POWER_CHOICE'], game_path), 60)
|
||||
|
||||
def test_game_history():
|
||||
""" Test a complete 1 player game and validate the full history (except last phase) """
|
||||
_ = Server() # Initialize cache to prevent timeouts during tests
|
||||
game_path = os.path.join(FILE_FOLDER_NAME, 'game_data_1_history.csv')
|
||||
run_with_timeout(lambda: run_game_data(1, ['NO_PRESS', 'IGNORE_ERRORS', 'POWER_CHOICE'], game_path), 60)
|
||||
|
||||
def test_game_7():
|
||||
""" Test a complete 7 players game """
|
||||
_ = Server() # Initialize cache to prevent timeouts during tests
|
||||
game_path = os.path.join(FILE_FOLDER_NAME, 'game_data_7.csv')
|
||||
run_with_timeout(lambda: run_game_data(7, ['NO_PRESS', 'IGNORE_ERRORS', 'POWER_CHOICE'], game_path), 60)
|
||||
|
||||
def test_game_7_draw():
|
||||
""" Test a complete 7 players game that ends with a draw """
|
||||
_ = Server() # Initialize cache to prevent timeouts during tests
|
||||
game_path = os.path.join(FILE_FOLDER_NAME, 'game_data_7_draw.csv')
|
||||
run_with_timeout(lambda: run_game_data(7, ['NO_PRESS', 'IGNORE_ERRORS', 'POWER_CHOICE'], game_path), 60)
|
||||
|
||||
def test_game_7_press():
|
||||
""" Test a complete 7 players game with press """
|
||||
_ = Server() # Initialize cache to prevent timeouts during tests
|
||||
game_path = os.path.join(FILE_FOLDER_NAME, 'game_data_7_press.csv')
|
||||
run_with_timeout(lambda: run_game_data(7, ['IGNORE_ERRORS', 'POWER_CHOICE'], game_path), 60)
|
||||
Loading…
Add table
Add a link
Reference in a new issue