mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-19 12:58:09 +00:00
Adding PowerEnum as a seperate model
The PowerEnum correctly handles some misspellings. It can be easily expanded to handle more within the _POWER_ALIASES dict.
This commit is contained in:
parent
a241e34496
commit
540c2003e8
8 changed files with 1236 additions and 1053 deletions
File diff suppressed because it is too large
Load diff
|
|
@ -29,6 +29,7 @@ from collections import defaultdict, Counter
|
||||||
import re
|
import re
|
||||||
from typing import Dict, List, Tuple, Optional, Any
|
from typing import Dict, List, Tuple, Optional, Any
|
||||||
import statistics
|
import statistics
|
||||||
|
from ..models import PowerEnum
|
||||||
|
|
||||||
class StatisticalGameAnalyzer:
|
class StatisticalGameAnalyzer:
|
||||||
"""Production-ready analyzer for AI Diplomacy game statistics.
|
"""Production-ready analyzer for AI Diplomacy game statistics.
|
||||||
|
|
@ -47,7 +48,6 @@ class StatisticalGameAnalyzer:
|
||||||
'Ally': 2
|
'Ally': 2
|
||||||
}
|
}
|
||||||
|
|
||||||
DIPLOMACY_POWERS = ['AUSTRIA', 'ENGLAND', 'FRANCE', 'GERMANY', 'ITALY', 'RUSSIA', 'TURKEY']
|
|
||||||
|
|
||||||
# Complete list of response types found in actual data
|
# Complete list of response types found in actual data
|
||||||
RESPONSE_TYPES = [
|
RESPONSE_TYPES = [
|
||||||
|
|
@ -59,44 +59,8 @@ class StatisticalGameAnalyzer:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize analyzer with configuration constants."""
|
"""Initialize analyzer with configuration constants."""
|
||||||
self.relationship_values = self.RELATIONSHIP_VALUES
|
self.relationship_values = self.RELATIONSHIP_VALUES
|
||||||
self.powers = self.DIPLOMACY_POWERS
|
|
||||||
|
|
||||||
def _normalize_recipient_name(self, recipient: str) -> str:
|
|
||||||
"""Normalize recipient names to handle LLM typos and abbreviations."""
|
|
||||||
if not recipient:
|
|
||||||
return recipient
|
|
||||||
|
|
||||||
recipient = recipient.upper().strip()
|
|
||||||
|
|
||||||
# Handle common LLM typos and abbreviations found in data
|
|
||||||
name_mapping = {
|
|
||||||
'EGMANY': 'GERMANY',
|
|
||||||
'GERMAN': 'GERMANY',
|
|
||||||
'UK': 'ENGLAND',
|
|
||||||
'BRIT': 'ENGLAND',
|
|
||||||
'ENGLAND': 'ENGLAND', # Keep as-is
|
|
||||||
'FRANCE': 'FRANCE', # Keep as-is
|
|
||||||
'GERMANY': 'GERMANY', # Keep as-is
|
|
||||||
'ITALY': 'ITALY', # Keep as-is
|
|
||||||
'AUSTRIA': 'AUSTRIA', # Keep as-is
|
|
||||||
'RUSSIA': 'RUSSIA', # Keep as-is
|
|
||||||
'TURKEY': 'TURKEY', # Keep as-is
|
|
||||||
'Germany': 'GERMANY',
|
|
||||||
'England': 'ENGLAND',
|
|
||||||
'France': 'FRANCE',
|
|
||||||
'Italy': 'ITALY',
|
|
||||||
'Russia': 'RUSSIA',
|
|
||||||
'Austria': 'AUSTRIA',
|
|
||||||
'Turkey': 'TURKEY',
|
|
||||||
}
|
|
||||||
|
|
||||||
normalized = name_mapping.get(recipient, recipient)
|
|
||||||
|
|
||||||
# Validate it's a known power
|
|
||||||
if normalized not in self.DIPLOMACY_POWERS:
|
|
||||||
return None # Invalid recipient
|
|
||||||
|
|
||||||
return normalized
|
|
||||||
|
|
||||||
def analyze_folder(self, folder_path: str, output_dir: str = None) -> Tuple[str, str]:
|
def analyze_folder(self, folder_path: str, output_dir: str = None) -> Tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -272,11 +236,11 @@ class StatisticalGameAnalyzer:
|
||||||
if not phase_data:
|
if not phase_data:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for power in self.powers:
|
for power in PowerEnum:
|
||||||
for response_type in response_types:
|
for response_type in response_types:
|
||||||
# Extract features for this specific power/phase/response_type combination
|
# Extract features for this specific power/phase/response_type combination
|
||||||
features = self._extract_power_phase_response_features(
|
features = self._extract_power_phase_response_features(
|
||||||
power, phase_name, response_type, llm_responses, phase_data, game_data
|
power.value, phase_name, response_type, llm_responses, phase_data, game_data
|
||||||
)
|
)
|
||||||
if features:
|
if features:
|
||||||
phase_features.append(features)
|
phase_features.append(features)
|
||||||
|
|
@ -530,7 +494,7 @@ class StatisticalGameAnalyzer:
|
||||||
game_features = []
|
game_features = []
|
||||||
game_scores = self._compute_game_scores(game_data)
|
game_scores = self._compute_game_scores(game_data)
|
||||||
|
|
||||||
for power in self.powers:
|
for power in PowerEnum:
|
||||||
features = {
|
features = {
|
||||||
# === IDENTIFIERS ===
|
# === IDENTIFIERS ===
|
||||||
'game_id': game_data.get('id', 'unknown'),
|
'game_id': game_data.get('id', 'unknown'),
|
||||||
|
|
@ -718,7 +682,8 @@ class StatisticalGameAnalyzer:
|
||||||
|
|
||||||
# Categorize by relationship
|
# Categorize by relationship
|
||||||
recipient = msg.get('recipient_power')
|
recipient = msg.get('recipient_power')
|
||||||
normalized_recipient = self._normalize_recipient_name(recipient)
|
# This will coerce some known aliases to match the 7 acceptable names (see models.py)
|
||||||
|
normalized_recipient = PowerEnum(recipient)
|
||||||
|
|
||||||
# Skip self-messages and invalid recipients
|
# Skip self-messages and invalid recipients
|
||||||
if normalized_recipient and normalized_recipient != power and normalized_recipient in relationships:
|
if normalized_recipient and normalized_recipient != power and normalized_recipient in relationships:
|
||||||
|
|
@ -969,7 +934,7 @@ class StatisticalGameAnalyzer:
|
||||||
break
|
break
|
||||||
|
|
||||||
# elimination turn for every power
|
# elimination turn for every power
|
||||||
elim_turn: dict[str, int | None] = {p: None for p in self.DIPLOMACY_POWERS}
|
elim_turn: dict[str, int | None] = {p: None for p in [power.value for power in PowerEnum]}
|
||||||
for idx, ph in enumerate(phases):
|
for idx, ph in enumerate(phases):
|
||||||
yr = self._phase_year(phases, idx)
|
yr = self._phase_year(phases, idx)
|
||||||
if yr is None:
|
if yr is None:
|
||||||
|
|
@ -1223,4 +1188,4 @@ def main():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
exit(main())
|
exit(main())
|
||||||
|
|
|
||||||
48
bot_client/config.py
Normal file
48
bot_client/config.py
Normal file
|
|
@ -0,0 +1,48 @@
|
||||||
|
import datetime
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
from pathlib import Path
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
|
||||||
|
class Configuration(BaseSettings):
|
||||||
|
DEBUG: bool = False
|
||||||
|
log_file_path: Path | None = None
|
||||||
|
DEEPSEEK_API_KEY: str | None = None
|
||||||
|
OPENAI_API_KEY: str | None = None
|
||||||
|
ANTHROPIC_API_KEY: str | None = None
|
||||||
|
GEMINI_API_KEY: str | None = None
|
||||||
|
OPENROUTER_API_KEY: str | None = None
|
||||||
|
|
||||||
|
def __init__(self, power_name, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.log_file_path = Path(f"./logs/{datetime.datetime.now().strftime('%d-%m-%y_%H:%M')}/{power_name}.txt")
|
||||||
|
# Make the path absolute, gets rid of weirdness of calling this in different places
|
||||||
|
self.log_file_path = self.log_file_path.resolve()
|
||||||
|
self.log_file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self.log_file_path.touch(exist_ok=True)
|
||||||
|
|
||||||
|
self._validate_api_keys()
|
||||||
|
|
||||||
|
def _validate_api_keys(self):
|
||||||
|
"""Validate API keys at startup and issue warnings for missing keys"""
|
||||||
|
api_keys = [
|
||||||
|
"DEEPSEEK_API_KEY",
|
||||||
|
"OPENAI_API_KEY",
|
||||||
|
"ANTHROPIC_API_KEY",
|
||||||
|
"GEMINI_API_KEY",
|
||||||
|
"OPENROUTER_API_KEY",
|
||||||
|
]
|
||||||
|
|
||||||
|
for key in api_keys:
|
||||||
|
value = super().__getattribute__(key)
|
||||||
|
if not value or (isinstance(value, str) and len(value) == 0):
|
||||||
|
warnings.warn(f"API key '{key}' is not set or is empty", UserWarning)
|
||||||
|
|
||||||
|
def __getattribute__(self, name):
|
||||||
|
"""Override to check for empty API keys at access time"""
|
||||||
|
value = super().__getattribute__(name)
|
||||||
|
|
||||||
|
if name.endswith("_KEY") and (not value or (isinstance(value, str) and len(value) == 0)):
|
||||||
|
raise ValueError(f"API key '{name}' is not set or is empty. Please configure it before use.")
|
||||||
|
|
||||||
|
return value
|
||||||
60
config.py
Normal file
60
config.py
Normal file
|
|
@ -0,0 +1,60 @@
|
||||||
|
import datetime
|
||||||
|
from typing import Optional
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
from pathlib import Path
|
||||||
|
import warnings
|
||||||
|
from models import PowerEnum
|
||||||
|
|
||||||
|
|
||||||
|
class Configuration(BaseSettings):
|
||||||
|
DEBUG: bool = False
|
||||||
|
log_file_path: Path | None = None
|
||||||
|
USE_UNFORMATTED_PROMPTS: bool = False
|
||||||
|
|
||||||
|
# API Keys to be validated. Warns if they aren't present at startup, raises ValueError if you attempt to use them when they aren't present.
|
||||||
|
DEEPSEEK_API_KEY: str | None = None
|
||||||
|
OPENAI_API_KEY: str | None = None
|
||||||
|
ANTHROPIC_API_KEY: str | None = None
|
||||||
|
GEMINI_API_KEY: str | None = None
|
||||||
|
OPENROUTER_API_KEY: str | None = None
|
||||||
|
|
||||||
|
def __init__(self, power_name: Optional[PowerEnum] = None, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
# Add a '-POWER' to the end of the file name if it's for a specific power
|
||||||
|
log_power_path = "-" + power_name if power_name else None
|
||||||
|
self.log_file_path = Path(f"./logs/{datetime.datetime.now().strftime('%d-%m-%y_%H:%M')}/logs{log_power_path} .txt")
|
||||||
|
# Make the path absolute, gets rid of weirdness of calling this in different places
|
||||||
|
self.log_file_path = self.log_file_path.resolve()
|
||||||
|
self.log_file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self.log_file_path.touch(exist_ok=True)
|
||||||
|
|
||||||
|
self._validate_api_keys()
|
||||||
|
|
||||||
|
def _validate_api_keys(self):
|
||||||
|
"""Validate API keys at startup and issue warnings for missing keys"""
|
||||||
|
api_keys = [
|
||||||
|
"DEEPSEEK_API_KEY",
|
||||||
|
"OPENAI_API_KEY",
|
||||||
|
"ANTHROPIC_API_KEY",
|
||||||
|
"GEMINI_API_KEY",
|
||||||
|
"OPENROUTER_API_KEY",
|
||||||
|
]
|
||||||
|
|
||||||
|
for key in api_keys:
|
||||||
|
value = super().__getattribute__(key)
|
||||||
|
if not value or (isinstance(value, str) and len(value) == 0):
|
||||||
|
warnings.warn(f"API key '{key}' is not set or is empty", UserWarning)
|
||||||
|
|
||||||
|
def __getattribute__(self, name):
|
||||||
|
"""Override to check for empty API keys at access time"""
|
||||||
|
value = super().__getattribute__(name)
|
||||||
|
|
||||||
|
# If this is a _KEY, it must be not None, string, and length > 0 to return. We do not validate the correctness of the key.
|
||||||
|
# e.g. "thisIsAKey" is valid in this sense.
|
||||||
|
if name.endswith("_KEY") and (not value or (isinstance(value, str) and len(value) == 0)):
|
||||||
|
raise ValueError(f"API key '{name}' is not set or is empty. Please configure it before use.")
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
config = Configuration()
|
||||||
38
models.py
Normal file
38
models.py
Normal file
|
|
@ -0,0 +1,38 @@
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
# your “typo → canonical” map
|
||||||
|
_POWER_ALIASES = {
|
||||||
|
"EGMANY": "GERMANY",
|
||||||
|
"GERMAN": "GERMANY",
|
||||||
|
"UK": "ENGLAND",
|
||||||
|
"BRIT": "ENGLAND",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class PowerEnum(str, Enum):
|
||||||
|
AUSTRIA = "AUSTRIA"
|
||||||
|
ENGLAND = "ENGLAND"
|
||||||
|
FRANCE = "FRANCE"
|
||||||
|
GERMANY = "GERMANY"
|
||||||
|
ITALY = "ITALY"
|
||||||
|
RUSSIA = "RUSSIA"
|
||||||
|
TURKEY = "TURKEY"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _missing_(cls, value: Any) -> Optional["Enum"]:
|
||||||
|
"""
|
||||||
|
Called when you do PowerEnum(value) and `value` isn't one of the raw enum values.
|
||||||
|
Here we normalize strings to upper‐stripped, apply aliases, then retry.
|
||||||
|
"""
|
||||||
|
if isinstance(value, str):
|
||||||
|
normalized = value.upper().strip()
|
||||||
|
# apply any synonyms/typos
|
||||||
|
normalized = _POWER_ALIASES.get(normalized, normalized)
|
||||||
|
# look up in the normal value→member map
|
||||||
|
member = cls._value2member_map_.get(normalized)
|
||||||
|
if member is not None:
|
||||||
|
return member
|
||||||
|
|
||||||
|
# by default, let Enum raise the ValueError
|
||||||
|
return super()._missing_(value)
|
||||||
|
|
@ -5,28 +5,32 @@ description = "Add your description here"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.13"
|
requires-python = ">=3.13"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anthropic>=0.54.0",
|
"anthropic>=0.54.0",
|
||||||
"bcrypt>=4.3.0",
|
"bcrypt>=4.3.0",
|
||||||
"coloredlogs>=15.0.1",
|
"coloredlogs>=15.0.1",
|
||||||
"dotenv>=0.9.9",
|
"dotenv>=0.9.9",
|
||||||
"google-genai>=1.21.1",
|
"google-genai>=1.21.1",
|
||||||
"google-generativeai>=0.8.5",
|
"google-generativeai>=0.8.5",
|
||||||
"json-repair>=0.47.2",
|
"json-repair>=0.47.2",
|
||||||
"json5>=0.12.0",
|
"json5>=0.12.0",
|
||||||
"matplotlib>=3.10.3",
|
"matplotlib>=3.10.3",
|
||||||
"openai>=1.90.0",
|
"openai>=1.90.0",
|
||||||
"pylint>=2.3.0",
|
"pydantic-settings>=2.10.1",
|
||||||
"pytest>=4.4.0",
|
"pylint>=2.3.0",
|
||||||
"pytest-xdist>=3.7.0",
|
"pytest>=4.4.0",
|
||||||
"python-dateutil>=2.9.0.post0",
|
"pytest-xdist>=3.7.0",
|
||||||
"pytz>=2025.2",
|
"python-dateutil>=2.9.0.post0",
|
||||||
"scipy>=1.16.0",
|
"pytz>=2025.2",
|
||||||
"seaborn>=0.13.2",
|
"scipy>=1.16.0",
|
||||||
"sphinx>=8.2.3",
|
"seaborn>=0.13.2",
|
||||||
"sphinx-copybutton>=0.5.2",
|
"sphinx>=8.2.3",
|
||||||
"sphinx-rtd-theme>=3.0.2",
|
"sphinx-copybutton>=0.5.2",
|
||||||
"together>=1.5.17",
|
"sphinx-rtd-theme>=3.0.2",
|
||||||
"tornado>=5.0",
|
"together>=1.5.17",
|
||||||
"tqdm>=4.67.1",
|
"tornado>=5.0",
|
||||||
"ujson>=5.10.0",
|
"tqdm>=4.67.1",
|
||||||
|
"ujson>=5.10.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
line-length = 150
|
||||||
|
|
|
||||||
21
tests/test_models.py
Normal file
21
tests/test_models.py
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
from models import PowerEnum
|
||||||
|
|
||||||
|
|
||||||
|
def test_power_name_aliases():
|
||||||
|
# Test all aliases defined in _POWER_ALIASES
|
||||||
|
assert PowerEnum("UK") == PowerEnum.ENGLAND
|
||||||
|
assert PowerEnum("BRIT") == PowerEnum.ENGLAND
|
||||||
|
assert PowerEnum("EGMANY") == PowerEnum.GERMANY
|
||||||
|
assert PowerEnum("GERMAN") == PowerEnum.GERMANY
|
||||||
|
|
||||||
|
# Test direct enum values (no alias needed)
|
||||||
|
assert PowerEnum("AUSTRIA") == PowerEnum.AUSTRIA
|
||||||
|
assert PowerEnum("FRANCE") == PowerEnum.FRANCE
|
||||||
|
|
||||||
|
# Test case insensitivity
|
||||||
|
assert PowerEnum("france") == PowerEnum.FRANCE
|
||||||
|
assert PowerEnum("iTaLy") == PowerEnum.ITALY
|
||||||
|
|
||||||
|
# Test with whitespace
|
||||||
|
assert PowerEnum(" RUSSIA ") == PowerEnum.RUSSIA
|
||||||
|
assert PowerEnum("TURKEY ") == PowerEnum.TURKEY
|
||||||
Loading…
Add table
Add a link
Reference in a new issue