Adding PowerEnum as a separate model

The PowerEnum correctly handles some misspellings. It can be easily
expanded to handle more within the _POWER_ALIASES dict.
This commit is contained in:
Tyler Marques 2025-07-03 12:06:47 -07:00
parent a241e34496
commit 540c2003e8
No known key found for this signature in database
GPG key ID: CB99EDCF41D3016F
8 changed files with 1236 additions and 1053 deletions

File diff suppressed because it is too large Load diff

View file

@ -29,6 +29,7 @@ from collections import defaultdict, Counter
import re import re
from typing import Dict, List, Tuple, Optional, Any from typing import Dict, List, Tuple, Optional, Any
import statistics import statistics
from ..models import PowerEnum
class StatisticalGameAnalyzer: class StatisticalGameAnalyzer:
"""Production-ready analyzer for AI Diplomacy game statistics. """Production-ready analyzer for AI Diplomacy game statistics.
@ -47,7 +48,6 @@ class StatisticalGameAnalyzer:
'Ally': 2 'Ally': 2
} }
DIPLOMACY_POWERS = ['AUSTRIA', 'ENGLAND', 'FRANCE', 'GERMANY', 'ITALY', 'RUSSIA', 'TURKEY']
# Complete list of response types found in actual data # Complete list of response types found in actual data
RESPONSE_TYPES = [ RESPONSE_TYPES = [
@ -59,44 +59,8 @@ class StatisticalGameAnalyzer:
def __init__(self): def __init__(self):
"""Initialize analyzer with configuration constants.""" """Initialize analyzer with configuration constants."""
self.relationship_values = self.RELATIONSHIP_VALUES self.relationship_values = self.RELATIONSHIP_VALUES
self.powers = self.DIPLOMACY_POWERS
def _normalize_recipient_name(self, recipient: str) -> str:
"""Normalize recipient names to handle LLM typos and abbreviations."""
if not recipient:
return recipient
recipient = recipient.upper().strip()
# Handle common LLM typos and abbreviations found in data
name_mapping = {
'EGMANY': 'GERMANY',
'GERMAN': 'GERMANY',
'UK': 'ENGLAND',
'BRIT': 'ENGLAND',
'ENGLAND': 'ENGLAND', # Keep as-is
'FRANCE': 'FRANCE', # Keep as-is
'GERMANY': 'GERMANY', # Keep as-is
'ITALY': 'ITALY', # Keep as-is
'AUSTRIA': 'AUSTRIA', # Keep as-is
'RUSSIA': 'RUSSIA', # Keep as-is
'TURKEY': 'TURKEY', # Keep as-is
'Germany': 'GERMANY',
'England': 'ENGLAND',
'France': 'FRANCE',
'Italy': 'ITALY',
'Russia': 'RUSSIA',
'Austria': 'AUSTRIA',
'Turkey': 'TURKEY',
}
normalized = name_mapping.get(recipient, recipient)
# Validate it's a known power
if normalized not in self.DIPLOMACY_POWERS:
return None # Invalid recipient
return normalized
def analyze_folder(self, folder_path: str, output_dir: str = None) -> Tuple[str, str]: def analyze_folder(self, folder_path: str, output_dir: str = None) -> Tuple[str, str]:
""" """
@ -272,11 +236,11 @@ class StatisticalGameAnalyzer:
if not phase_data: if not phase_data:
continue continue
for power in self.powers: for power in PowerEnum:
for response_type in response_types: for response_type in response_types:
# Extract features for this specific power/phase/response_type combination # Extract features for this specific power/phase/response_type combination
features = self._extract_power_phase_response_features( features = self._extract_power_phase_response_features(
power, phase_name, response_type, llm_responses, phase_data, game_data power.value, phase_name, response_type, llm_responses, phase_data, game_data
) )
if features: if features:
phase_features.append(features) phase_features.append(features)
@ -530,7 +494,7 @@ class StatisticalGameAnalyzer:
game_features = [] game_features = []
game_scores = self._compute_game_scores(game_data) game_scores = self._compute_game_scores(game_data)
for power in self.powers: for power in PowerEnum:
features = { features = {
# === IDENTIFIERS === # === IDENTIFIERS ===
'game_id': game_data.get('id', 'unknown'), 'game_id': game_data.get('id', 'unknown'),
@ -718,7 +682,8 @@ class StatisticalGameAnalyzer:
# Categorize by relationship # Categorize by relationship
recipient = msg.get('recipient_power') recipient = msg.get('recipient_power')
normalized_recipient = self._normalize_recipient_name(recipient) # This will coerce some known aliases to match the 7 acceptable names (see models.py)
normalized_recipient = PowerEnum(recipient)
# Skip self-messages and invalid recipients # Skip self-messages and invalid recipients
if normalized_recipient and normalized_recipient != power and normalized_recipient in relationships: if normalized_recipient and normalized_recipient != power and normalized_recipient in relationships:
@ -969,7 +934,7 @@ class StatisticalGameAnalyzer:
break break
# elimination turn for every power # elimination turn for every power
elim_turn: dict[str, int | None] = {p: None for p in self.DIPLOMACY_POWERS} elim_turn: dict[str, int | None] = {p: None for p in [power.value for power in PowerEnum]}
for idx, ph in enumerate(phases): for idx, ph in enumerate(phases):
yr = self._phase_year(phases, idx) yr = self._phase_year(phases, idx)
if yr is None: if yr is None:
@ -1223,4 +1188,4 @@ def main():
if __name__ == '__main__': if __name__ == '__main__':
exit(main()) exit(main())

48
bot_client/config.py Normal file
View file

@ -0,0 +1,48 @@
import datetime
from pydantic_settings import BaseSettings
from pathlib import Path
import warnings
class Configuration(BaseSettings):
    """Per-power bot settings: debug flag, log file location and LLM API keys.

    Constructing an instance creates the log file on disk and warns about every
    API key that is missing. Reading a missing key afterwards raises ``ValueError``.
    """

    DEBUG: bool = False
    # Always overwritten in __init__ with an absolute, per-power path.
    log_file_path: Path | None = None
    DEEPSEEK_API_KEY: str | None = None
    OPENAI_API_KEY: str | None = None
    ANTHROPIC_API_KEY: str | None = None
    GEMINI_API_KEY: str | None = None
    OPENROUTER_API_KEY: str | None = None

    def __init__(self, power_name, **kwargs):
        """Load the settings and create the log file for ``power_name``.

        Args:
            power_name: Name of the power this bot plays; becomes the log file
                name. May be a plain string or an enum member.
            **kwargs: Forwarded unchanged to ``BaseSettings.__init__``.
        """
        from enum import Enum

        super().__init__(**kwargs)

        # Enum members (str mix-ins included) render as "ClassName.MEMBER" inside
        # f-strings on Python 3.11+, so use the underlying value for the file name.
        power_label = power_name.value if isinstance(power_name, Enum) else power_name

        # NOTE(review): ':' in the directory name is not a legal path character on Windows.
        run_stamp = datetime.datetime.now().strftime("%d-%m-%y_%H:%M")

        # Make the path absolute so it does not depend on the caller's working directory.
        log_path = Path(f"./logs/{run_stamp}/{power_label}.txt").resolve()
        self.log_file_path = log_path
        log_path.parent.mkdir(parents=True, exist_ok=True)
        log_path.touch(exist_ok=True)

        self._validate_api_keys()

    def _validate_api_keys(self):
        """Emit a ``UserWarning`` for every API key that is unset or empty."""
        api_keys = [
            "DEEPSEEK_API_KEY",
            "OPENAI_API_KEY",
            "ANTHROPIC_API_KEY",
            "GEMINI_API_KEY",
            "OPENROUTER_API_KEY",
        ]
        for key in api_keys:
            # Go through the parent lookup: this class's __getattribute__ raises on missing keys.
            value = super().__getattribute__(key)
            if not value:  # covers both None and ""
                warnings.warn(f"API key '{key}' is not set or is empty", UserWarning)

    def __getattribute__(self, name):
        """Return the attribute, refusing to hand out an unset or empty ``*_KEY`` value.

        Raises:
            ValueError: If ``name`` ends with ``_KEY`` and its value is falsy.
        """
        value = super().__getattribute__(name)
        if name.endswith("_KEY") and not value:
            raise ValueError(f"API key '{name}' is not set or is empty. Please configure it before use.")
        return value

60
config.py Normal file
View file

@ -0,0 +1,60 @@
import datetime
from typing import Optional
from pydantic_settings import BaseSettings
from pathlib import Path
import warnings
from models import PowerEnum
class Configuration(BaseSettings):
    """Application settings: debug/prompt flags, log file location and LLM API keys.

    Constructing an instance creates the log file on disk and warns about every
    API key that is missing. Reading a missing key afterwards raises ``ValueError``.
    """

    DEBUG: bool = False
    # Always overwritten in __init__ with an absolute path.
    log_file_path: Path | None = None
    USE_UNFORMATTED_PROMPTS: bool = False

    # API Keys to be validated. Warns if they aren't present at startup,
    # raises ValueError if you attempt to use them when they aren't present.
    DEEPSEEK_API_KEY: str | None = None
    OPENAI_API_KEY: str | None = None
    ANTHROPIC_API_KEY: str | None = None
    GEMINI_API_KEY: str | None = None
    OPENROUTER_API_KEY: str | None = None

    def __init__(self, power_name: Optional[PowerEnum] = None, **kwargs):
        """Load the settings and create the log file.

        Args:
            power_name: Optional power this configuration belongs to. When given,
                "-<POWER>" is appended to the log file name ("logs-<POWER>.txt");
                otherwise the file is simply "logs.txt".
            **kwargs: Forwarded unchanged to ``BaseSettings.__init__``.
        """
        from enum import Enum

        super().__init__(**kwargs)

        if power_name:
            # Use the enum's value explicitly: enum members render as
            # "ClassName.MEMBER" inside f-strings on Python 3.11+.
            power_label = power_name.value if isinstance(power_name, Enum) else power_name
            log_suffix = f"-{power_label}"
        else:
            # Empty suffix (not None) so the name never contains the text "None".
            log_suffix = ""

        # NOTE(review): ':' in the directory name is not a legal path character on Windows.
        run_stamp = datetime.datetime.now().strftime("%d-%m-%y_%H:%M")

        # Make the path absolute so it does not depend on the caller's working directory.
        log_path = Path(f"./logs/{run_stamp}/logs{log_suffix}.txt").resolve()
        self.log_file_path = log_path
        log_path.parent.mkdir(parents=True, exist_ok=True)
        log_path.touch(exist_ok=True)

        self._validate_api_keys()

    def _validate_api_keys(self):
        """Emit a ``UserWarning`` for every API key that is unset or empty."""
        api_keys = [
            "DEEPSEEK_API_KEY",
            "OPENAI_API_KEY",
            "ANTHROPIC_API_KEY",
            "GEMINI_API_KEY",
            "OPENROUTER_API_KEY",
        ]
        for key in api_keys:
            # Go through the parent lookup: this class's __getattribute__ raises on missing keys.
            value = super().__getattribute__(key)
            if not value:  # covers both None and ""
                warnings.warn(f"API key '{key}' is not set or is empty", UserWarning)

    def __getattribute__(self, name):
        """Return the attribute, refusing to hand out an unset or empty ``*_KEY`` value.

        Only presence is checked, not correctness: any non-empty value
        (e.g. "thisIsAKey") is returned as-is.

        Raises:
            ValueError: If ``name`` ends with ``_KEY`` and its value is falsy.
        """
        value = super().__getattribute__(name)
        if name.endswith("_KEY") and not value:
            raise ValueError(f"API key '{name}' is not set or is empty. Please configure it before use.")
        return value
# Module-level settings instance, built at import time without a power name.
config = Configuration()

38
models.py Normal file
View file

@ -0,0 +1,38 @@
from enum import Enum
from typing import Any, Optional
# Known misspellings and abbreviations (already upper-cased) -> canonical power name.
# Add further entries here to accept more variants.
_POWER_ALIASES = dict(
    EGMANY="GERMANY",
    GERMAN="GERMANY",
    UK="ENGLAND",
    BRIT="ENGLAND",
)


class PowerEnum(str, Enum):
    """The seven Diplomacy powers as string-valued enum members.

    ``PowerEnum(text)`` is forgiving: surrounding whitespace and letter case
    are ignored, and the spellings listed in ``_POWER_ALIASES`` resolve to
    their canonical member. Anything else raises ``ValueError`` as usual.
    """

    AUSTRIA = "AUSTRIA"
    ENGLAND = "ENGLAND"
    FRANCE = "FRANCE"
    GERMANY = "GERMANY"
    ITALY = "ITALY"
    RUSSIA = "RUSSIA"
    TURKEY = "TURKEY"

    @classmethod
    def _missing_(cls, value: Any) -> Optional["Enum"]:
        """Resolve a value that is not an exact member value.

        Invoked by ``Enum`` when ``PowerEnum(value)`` finds no direct match.
        Strings are upper-cased and stripped, mapped through the alias table,
        and looked up again; every other case is left to the default handling.
        """
        # Non-strings cannot be normalised: defer to Enum, which raises ValueError.
        if not isinstance(value, str):
            return super()._missing_(value)

        key = value.upper().strip()
        canonical = _POWER_ALIASES.get(key, key)

        # Second attempt against the value -> member map with the cleaned-up name.
        resolved = cls._value2member_map_.get(canonical)
        return resolved if resolved is not None else super()._missing_(value)

View file

@ -5,28 +5,32 @@ description = "Add your description here"
readme = "README.md" readme = "README.md"
requires-python = ">=3.13" requires-python = ">=3.13"
dependencies = [ dependencies = [
"anthropic>=0.54.0", "anthropic>=0.54.0",
"bcrypt>=4.3.0", "bcrypt>=4.3.0",
"coloredlogs>=15.0.1", "coloredlogs>=15.0.1",
"dotenv>=0.9.9", "dotenv>=0.9.9",
"google-genai>=1.21.1", "google-genai>=1.21.1",
"google-generativeai>=0.8.5", "google-generativeai>=0.8.5",
"json-repair>=0.47.2", "json-repair>=0.47.2",
"json5>=0.12.0", "json5>=0.12.0",
"matplotlib>=3.10.3", "matplotlib>=3.10.3",
"openai>=1.90.0", "openai>=1.90.0",
"pylint>=2.3.0", "pydantic-settings>=2.10.1",
"pytest>=4.4.0", "pylint>=2.3.0",
"pytest-xdist>=3.7.0", "pytest>=4.4.0",
"python-dateutil>=2.9.0.post0", "pytest-xdist>=3.7.0",
"pytz>=2025.2", "python-dateutil>=2.9.0.post0",
"scipy>=1.16.0", "pytz>=2025.2",
"seaborn>=0.13.2", "scipy>=1.16.0",
"sphinx>=8.2.3", "seaborn>=0.13.2",
"sphinx-copybutton>=0.5.2", "sphinx>=8.2.3",
"sphinx-rtd-theme>=3.0.2", "sphinx-copybutton>=0.5.2",
"together>=1.5.17", "sphinx-rtd-theme>=3.0.2",
"tornado>=5.0", "together>=1.5.17",
"tqdm>=4.67.1", "tornado>=5.0",
"ujson>=5.10.0", "tqdm>=4.67.1",
"ujson>=5.10.0",
] ]
[tool.ruff]
line-length = 150

21
tests/test_models.py Normal file
View file

@ -0,0 +1,21 @@
from models import PowerEnum
def test_power_name_aliases():
    """PowerEnum lookups resolve aliases, ignore case and trim surrounding whitespace."""
    expected_by_input = {
        # Every alias defined in _POWER_ALIASES
        "UK": PowerEnum.ENGLAND,
        "BRIT": PowerEnum.ENGLAND,
        "EGMANY": PowerEnum.GERMANY,
        "GERMAN": PowerEnum.GERMANY,
        # Exact enum values (no alias involved)
        "AUSTRIA": PowerEnum.AUSTRIA,
        "FRANCE": PowerEnum.FRANCE,
        # Case-insensitive matching
        "france": PowerEnum.FRANCE,
        "iTaLy": PowerEnum.ITALY,
        # Leading/trailing whitespace
        " RUSSIA ": PowerEnum.RUSSIA,
        "TURKEY ": PowerEnum.TURKEY,
    }

    for raw_name, expected_member in expected_by_input.items():
        assert PowerEnum(raw_name) == expected_member

1373
uv.lock generated

File diff suppressed because it is too large Load diff