AI_Diplomacy/diplomacy/utils/sorted_dict.py
2025-02-06 14:33:10 -08:00

267 lines
11 KiB
Python

# ==============================================================================
# Copyright (C) 2019 - Philip Paquette, Steven Bocco
#
# This program is free software: you can redistribute it and/or modify it under
# the terms of the GNU Affero General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
# details.
#
# You should have received a copy of the GNU Affero General Public License along
# with this program. If not, see <https://www.gnu.org/licenses/>.
# ==============================================================================
""" Helper class to provide a dict with sorted keys. """
from diplomacy.utils.common import is_dictionary
from diplomacy.utils.sorted_set import SortedSet
class SortedDict:
    """ Dict with sorted keys.

    Values are stored in an internal dict (key -> value), while key order is kept in a
    SortedSet built with the expected key type. Lookups go through the internal dict;
    ordered traversal (iteration, first/last, intervals) goes through the SortedSet.
    """
    __slots__ = ['__val_type', '__keys', '__couples']

    def __init__(self, key_type, val_type, kwargs=None):
        """ Initialize a typed SortedDict.

        :param key_type: expected type for keys.
        :param val_type: expected type for values.
        :param kwargs: (optional) dictionary-like object: initial values for sorted dict.
        """
        self.__val_type = val_type
        self.__keys = SortedSet(key_type)  # keys, in sorted order
        self.__couples = {}  # key -> value
        if kwargs is not None:
            assert is_dictionary(kwargs)
            for key, value in kwargs.items():
                self.put(key, value)

    @staticmethod
    def builder(key_type, val_type):
        """ Return a function to build sorted dicts from a dictionary-like object.

        Returned function expects a dictionary parameter (an object with method items()).

        .. code-block:: python

            builder_fn = SortedDict.builder(str, int)
            my_sorted_dict = builder_fn({'a': 1, 'b': 2})

        :param key_type: expected type for keys.
        :param val_type: expected type for values.
        :return: callable
        """
        return lambda dictionary: SortedDict(key_type, val_type, dictionary)

    @property
    def key_type(self):
        """ Get key type (the element type of the internal sorted key set). """
        return self.__keys.element_type

    @property
    def val_type(self):
        """ Get value type. """
        return self.__val_type

    def __str__(self):
        return 'SortedDict{%s}' % ', '.join('%s:%s' % (k, self.__couples[k]) for k in self.__keys)

    def __bool__(self):
        return bool(self.__keys)

    def __len__(self):
        return len(self.__keys)

    def __eq__(self, other):
        """ Return True if self and other are equal.

        Note that self and other must also have same key and value types.
        If other is not a SortedDict, return NotImplemented so that Python falls back
        to its default comparison instead of failing.
        """
        if not isinstance(other, SortedDict):
            return NotImplemented
        return (self.key_type is other.key_type
                and self.val_type is other.val_type
                and len(self) == len(other)
                and all(key in other and self[key] == other[key] for key in self.__keys))

    def __getitem__(self, key):
        return self.__couples[key]

    def __setitem__(self, key, value):
        self.put(key, value)

    def __delitem__(self, key):
        self.remove(key)

    def __iter__(self):
        return self.__keys.__iter__()

    def __contains__(self, key):
        return key in self.__couples

    def get(self, key, default=None):
        """ Return value associated with key, or default value if key not found. """
        return self.__couples.get(key, default)

    def put(self, key, value):
        """ Add a key with a value to the dict (replacing the value if key already exists).

        :param key: key to add.
        :param value: value to associate with key; must be an instance of the value type.
        :raise TypeError: if value is not an instance of the expected value type.
        """
        if not isinstance(value, self.__val_type):
            raise TypeError('Expected value type %s, got %s' % (self.__val_type, type(value)))
        if key not in self.__keys:
            # Key is registered in the sorted set first: if that fails, the value dict
            # is left untouched and both containers stay in sync.
            self.__keys.add(key)
        self.__couples[key] = value

    def remove(self, key):
        """ Pop (remove and return) value associated with given key, or None if key not found. """
        if key in self.__couples:
            self.__keys.remove(key)
        return self.__couples.pop(key, None)

    def first_key(self):
        """ Get the lowest key from the dict. """
        return self.__keys[0]

    def first_value(self):
        """ Get the value associated to lowest key in the dict. """
        return self.__couples[self.__keys[0]]

    def last_key(self):
        """ Get the highest key from the dict. """
        return self.__keys[-1]

    def last_value(self):
        """ Get the value associated to highest key in the dict. """
        return self.__couples[self.__keys[-1]]

    def last_item(self):
        """ Get the item (key-value pair) for the highest key in the dict. """
        return self.__keys[-1], self.__couples[self.__keys[-1]]

    def keys(self):
        """ Get an iterator to the keys in the dict. """
        return iter(self.__keys)

    def values(self):
        """ Get an iterator to the values in the dict. """
        return (self.__couples[k] for k in self.__keys)

    def reversed_values(self):
        """ Get an iterator to the values in the dict in reversed order of keys. """
        return (self.__couples[k] for k in reversed(self.__keys))

    def items(self):
        """ Get an iterator to the items in the dict. """
        return ((k, self.__couples[k]) for k in self.__keys)

    def reversed_items(self):
        """ Get an iterator to the items in the dict in reversed order of keys. """
        return ((k, self.__couples[k]) for k in reversed(self.__keys))

    def sub_keys(self, key_from=None, key_to=None):
        """ Return list of keys between key_from and key_to (both bounds included).

        See sub() doc about key_from and key_to.
        """
        position_from, position_to = self._get_keys_interval(key_from, key_to)
        return self.__keys[position_from:(position_to + 1)]

    def sub(self, key_from=None, key_to=None):
        """ Return a list of values associated to keys between key_from and key_to
        (both bounds included).

        - If key_from is None, lowest key in dict is used.
        - If key_to is None, greatest key in dict is used.
        - If key_from is not in dict, lowest key in dict greater than key_from is used.
        - If key_to is not in dict, greatest key in dict less than key_to is used.
        - If dict is empty, return empty list.
        - If no key in dict lies within [key_from; key_to], return empty list.
        - If both keys are given and key_from > key_to, raise IndexError.
        - With keys (None, None) return a copy of all values.
        - With keys (None, key_to), return values from first to the one associated to key_to.
        - With keys (key_from, None), return values from the one associated to key_from to the last value.

        :param key_from: start key
        :param key_to: end key
        :return: list: values in closed keys interval [key_from; key_to]
        """
        position_from, position_to = self._get_keys_interval(key_from, key_to)
        return [self.__couples[k] for k in self.__keys[position_from:(position_to + 1)]]

    def remove_sub(self, key_from=None, key_to=None):
        """ Remove values associated to keys between key_from and key_to (both bounds included).

        See sub() doc about key_from and key_to.

        :param key_from: start key
        :param key_to: end key
        :return: nothing
        """
        position_from, position_to = self._get_keys_interval(key_from, key_to)
        keys_to_remove = self.__keys[position_from:(position_to + 1)]
        for key in keys_to_remove:
            self.remove(key)

    def key_from_index(self, index):
        """ Return key matching given position in sorted dict, or None for invalid position.

        Negative positions are accepted, as in Python sequence indexing.
        """
        return self.__keys[index] if -len(self.__keys) <= index < len(self.__keys) else None

    def get_previous_key(self, key):
        """ Return greatest key lower than given key, or None if not exists. """
        return self.__keys.get_previous_value(key)

    def get_next_key(self, key):
        """ Return smallest key greater than given key, or None if not exists. """
        return self.__keys.get_next_value(key)

    def _get_keys_interval(self, key_from, key_to):
        """ Get a couple of internal key positions (index of key_from, index of key_to) allowing
        to easily retrieve values in closed interval [index of key_from; index of key_to]
        corresponding to Python slice [index of key_from : (index of key_to + 1)]

        - If dict is empty, return (0, -1), so that python slice [0 : -1 + 1] corresponds to empty interval.
        - If key_from is None, lowest key in dict is used.
        - If key_to is None, greatest key in dict is used.
        - If key_from is not in dict, lowest key in dict greater than key_from is used.
        - If key_to is not in dict, greatest key in dict less than key_to is used.
        - If no key in dict lies within [key_from; key_to], return (0, -1) (empty interval).
        - With keys (None, None), we get interval of all values.
        - With keys (key_from, None), we get interval for values from key_from to the last key.
        - With keys (None, key_to), we get interval for values from the first key to key_to.

        :param key_from: start key
        :param key_to: end key
        :return: (int, int): couple of integers: (index of key_from, index of key_to).
        :raise IndexError: if both keys are given and key_from > key_to.
        """
        if not self:
            return 0, -1
        # Bounds given by the caller in the wrong order are a caller error. This must be decided
        # before bounds are snapped to existing keys: snapping can also invert correctly ordered
        # bounds when no key lies between them (e.g. keys {1, 5} with interval [2; 4]), and that
        # case simply means an empty interval.
        inverted_request = key_from is not None and key_to is not None and key_from > key_to
        if key_from is not None and key_from not in self.__couples:
            key_from = self.__keys.get_next_value(key_from)
            if key_from is None:
                # No key >= requested start: nothing to select.
                return 0, -1
        if key_to is not None and key_to not in self.__couples:
            key_to = self.__keys.get_previous_value(key_to)
            if key_to is None:
                # No key <= requested end: nothing to select.
                return 0, -1
        if key_from is None:
            key_from = self.first_key()
        if key_to is None:
            key_to = self.last_key()
        if key_from > key_to:
            if inverted_request:
                raise IndexError('expected key_from <= key_to (%s vs %s)' % (key_from, key_to))
            # Requested bounds were ordered but fell strictly between two consecutive keys.
            return 0, -1
        position_from = self.__keys.index(key_from)
        position_to = self.__keys.index(key_to)
        assert position_from is not None and position_to is not None
        return position_from, position_to

    def clear(self):
        """ Remove all items from dict. """
        self.__couples.clear()
        self.__keys.clear()

    def fill(self, dct):
        """ Add given dict to this sorted dict (existing keys get their values replaced).

        A falsy dct (None or empty) is ignored.
        """
        if dct:
            assert is_dictionary(dct)
            for key, value in dct.items():
                self.put(key, value)

    def copy(self):
        """ Return a shallow copy of this sorted dict (values themselves are not copied). """
        return SortedDict(self.__keys.element_type, self.__val_type, self.__couples)