mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-05-02 17:46:00 +00:00
Adding PowerEnum as a seperate model
The PowerEnum correctly handles some misspellings. It can be easily expanded to handle more within the _POWER_ALIASES dict.
This commit is contained in:
parent
a241e34496
commit
540c2003e8
8 changed files with 1236 additions and 1053 deletions
|
|
@ -29,6 +29,7 @@ from collections import defaultdict, Counter
|
|||
import re
|
||||
from typing import Dict, List, Tuple, Optional, Any
|
||||
import statistics
|
||||
from ..models import PowerEnum
|
||||
|
||||
class StatisticalGameAnalyzer:
|
||||
"""Production-ready analyzer for AI Diplomacy game statistics.
|
||||
|
|
@ -47,7 +48,6 @@ class StatisticalGameAnalyzer:
|
|||
'Ally': 2
|
||||
}
|
||||
|
||||
DIPLOMACY_POWERS = ['AUSTRIA', 'ENGLAND', 'FRANCE', 'GERMANY', 'ITALY', 'RUSSIA', 'TURKEY']
|
||||
|
||||
# Complete list of response types found in actual data
|
||||
RESPONSE_TYPES = [
|
||||
|
|
@ -59,44 +59,8 @@ class StatisticalGameAnalyzer:
|
|||
def __init__(self):
|
||||
"""Initialize analyzer with configuration constants."""
|
||||
self.relationship_values = self.RELATIONSHIP_VALUES
|
||||
self.powers = self.DIPLOMACY_POWERS
|
||||
|
||||
def _normalize_recipient_name(self, recipient: str) -> str:
|
||||
"""Normalize recipient names to handle LLM typos and abbreviations."""
|
||||
if not recipient:
|
||||
return recipient
|
||||
|
||||
recipient = recipient.upper().strip()
|
||||
|
||||
# Handle common LLM typos and abbreviations found in data
|
||||
name_mapping = {
|
||||
'EGMANY': 'GERMANY',
|
||||
'GERMAN': 'GERMANY',
|
||||
'UK': 'ENGLAND',
|
||||
'BRIT': 'ENGLAND',
|
||||
'ENGLAND': 'ENGLAND', # Keep as-is
|
||||
'FRANCE': 'FRANCE', # Keep as-is
|
||||
'GERMANY': 'GERMANY', # Keep as-is
|
||||
'ITALY': 'ITALY', # Keep as-is
|
||||
'AUSTRIA': 'AUSTRIA', # Keep as-is
|
||||
'RUSSIA': 'RUSSIA', # Keep as-is
|
||||
'TURKEY': 'TURKEY', # Keep as-is
|
||||
'Germany': 'GERMANY',
|
||||
'England': 'ENGLAND',
|
||||
'France': 'FRANCE',
|
||||
'Italy': 'ITALY',
|
||||
'Russia': 'RUSSIA',
|
||||
'Austria': 'AUSTRIA',
|
||||
'Turkey': 'TURKEY',
|
||||
}
|
||||
|
||||
normalized = name_mapping.get(recipient, recipient)
|
||||
|
||||
# Validate it's a known power
|
||||
if normalized not in self.DIPLOMACY_POWERS:
|
||||
return None # Invalid recipient
|
||||
|
||||
return normalized
|
||||
|
||||
def analyze_folder(self, folder_path: str, output_dir: str = None) -> Tuple[str, str]:
|
||||
"""
|
||||
|
|
@ -272,11 +236,11 @@ class StatisticalGameAnalyzer:
|
|||
if not phase_data:
|
||||
continue
|
||||
|
||||
for power in self.powers:
|
||||
for power in PowerEnum:
|
||||
for response_type in response_types:
|
||||
# Extract features for this specific power/phase/response_type combination
|
||||
features = self._extract_power_phase_response_features(
|
||||
power, phase_name, response_type, llm_responses, phase_data, game_data
|
||||
power.value, phase_name, response_type, llm_responses, phase_data, game_data
|
||||
)
|
||||
if features:
|
||||
phase_features.append(features)
|
||||
|
|
@ -530,7 +494,7 @@ class StatisticalGameAnalyzer:
|
|||
game_features = []
|
||||
game_scores = self._compute_game_scores(game_data)
|
||||
|
||||
for power in self.powers:
|
||||
for power in PowerEnum:
|
||||
features = {
|
||||
# === IDENTIFIERS ===
|
||||
'game_id': game_data.get('id', 'unknown'),
|
||||
|
|
@ -718,7 +682,8 @@ class StatisticalGameAnalyzer:
|
|||
|
||||
# Categorize by relationship
|
||||
recipient = msg.get('recipient_power')
|
||||
normalized_recipient = self._normalize_recipient_name(recipient)
|
||||
# This will coerce some known aliases to match the 7 acceptable names (see models.py)
|
||||
normalized_recipient = PowerEnum(recipient)
|
||||
|
||||
# Skip self-messages and invalid recipients
|
||||
if normalized_recipient and normalized_recipient != power and normalized_recipient in relationships:
|
||||
|
|
@ -969,7 +934,7 @@ class StatisticalGameAnalyzer:
|
|||
break
|
||||
|
||||
# elimination turn for every power
|
||||
elim_turn: dict[str, int | None] = {p: None for p in self.DIPLOMACY_POWERS}
|
||||
elim_turn: dict[str, int | None] = {p: None for p in [power.value for power in PowerEnum]}
|
||||
for idx, ph in enumerate(phases):
|
||||
yr = self._phase_year(phases, idx)
|
||||
if yr is None:
|
||||
|
|
@ -1223,4 +1188,4 @@ def main():
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
exit(main())
|
||||
exit(main())
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue