mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-19 12:58:09 +00:00
state update fixes & streamline prompts
This commit is contained in:
parent
1f154a7073
commit
b4a56126ec
17 changed files with 710 additions and 281 deletions
|
|
@ -69,6 +69,12 @@ class StatisticalGameAnalyzer:
|
|||
'order_generation', 'order_diary', 'state_update_parsing_empty_or_invalid_data',
|
||||
'diary_consolidation', 'state_update_partial_data', 'state_update_no_response'
|
||||
]
|
||||
|
||||
ORDER_TYPES = [
|
||||
"move", "hold", "support", "convoy",
|
||||
"build", "disband", "waive", "other",
|
||||
"retreat"
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize analyzer with configuration constants."""
|
||||
|
|
@ -234,6 +240,103 @@ class StatisticalGameAnalyzer:
|
|||
|
||||
return responses
|
||||
|
||||
def _extract_order_results_features(self, power: str, phase_data: dict) -> dict:
|
||||
"""
|
||||
Count orders and outcomes for a single power in one phase and add
|
||||
a success-rate (0-1) for every order type.
|
||||
"""
|
||||
features: dict[str, float | int] = {}
|
||||
for ot in self.ORDER_TYPES:
|
||||
plural = f"{ot}s" if not ot.endswith("s") else ot
|
||||
for metric in ("total", "success", "bounce", "void", "invalid"):
|
||||
features[f"orders_{plural}_{metric}"] = 0
|
||||
features[f"orders_{plural}_success_rate"] = 0.0 # ← new
|
||||
|
||||
orders_by_type = phase_data.get("order_results", {}).get(power, {})
|
||||
if not orders_by_type:
|
||||
return features
|
||||
|
||||
for otype, order_list in orders_by_type.items():
|
||||
otype = otype.lower()
|
||||
if otype not in self.ORDER_TYPES:
|
||||
otype = "other"
|
||||
plural = f"{otype}s" if not otype.endswith("s") else otype
|
||||
|
||||
for entry in order_list:
|
||||
result = str(entry.get("result", "")).lower().strip()
|
||||
key_base = f"orders_{plural}"
|
||||
features[f"{key_base}_total"] += 1
|
||||
match result:
|
||||
case "success":
|
||||
features[f"{key_base}_success"] += 1
|
||||
case "bounce":
|
||||
features[f"{key_base}_bounce"] += 1
|
||||
case "invalid":
|
||||
features[f"{key_base}_invalid"] += 1
|
||||
case _ if result in ("void", "void: no effect", ""):
|
||||
features[f"{key_base}_void"] += 1
|
||||
|
||||
# ── derive success rates ──
|
||||
for ot in self.ORDER_TYPES:
|
||||
plural = f"{ot}s" if not ot.endswith("s") else ot
|
||||
succ = features[f"orders_{plural}_success"]
|
||||
tot = features[f"orders_{plural}_total"]
|
||||
features[f"orders_{plural}_success_rate"] = succ / tot if tot else 0.0
|
||||
|
||||
return features
|
||||
|
||||
|
||||
|
||||
# ────────────────── GAME-LEVEL ORDER TOTALS ──────────────────
|
||||
def _aggregate_order_results(self, power: str, game_data: dict) -> dict:
|
||||
"""
|
||||
Sum every order-type/result pair over *all* phases for one power
|
||||
and add success-rate (0-1) columns.
|
||||
"""
|
||||
totals: dict[str, float | int] = {}
|
||||
for ot in self.ORDER_TYPES:
|
||||
plural = f"{ot}s" if not ot.endswith("s") else ot
|
||||
for metric in ("total", "success", "bounce", "void", "invalid"):
|
||||
totals[f"orders_{plural}_{metric}"] = 0
|
||||
totals[f"orders_{plural}_success_rate"] = 0.0 # ← new
|
||||
|
||||
for phase in game_data.get("phases", []):
|
||||
orders_by_type = phase.get("order_results", {}).get(power, {})
|
||||
if not orders_by_type:
|
||||
continue
|
||||
|
||||
for otype, order_list in orders_by_type.items():
|
||||
otype = otype.lower()
|
||||
if otype not in self.ORDER_TYPES:
|
||||
otype = "other"
|
||||
plural = f"{otype}s" if not otype.endswith("s") else otype
|
||||
|
||||
for entry in order_list:
|
||||
result = str(entry.get("result", "")).lower().strip()
|
||||
key_base = f"orders_{plural}"
|
||||
totals[f"{key_base}_total"] += 1
|
||||
match result:
|
||||
case "success":
|
||||
totals[f"{key_base}_success"] += 1
|
||||
case "bounce":
|
||||
totals[f"{key_base}_bounce"] += 1
|
||||
case "invalid":
|
||||
totals[f"{key_base}_invalid"] += 1
|
||||
case _ if result in ("void", "void: no effect", ""):
|
||||
totals[f"{key_base}_void"] += 1
|
||||
|
||||
# ── derive success rates ──
|
||||
for ot in self.ORDER_TYPES:
|
||||
plural = f"{ot}s" if not ot.endswith("s") else ot
|
||||
succ = totals[f"orders_{plural}_success"]
|
||||
tot = totals[f"orders_{plural}_total"]
|
||||
totals[f"orders_{plural}_success_rate"] = succ / tot if tot else 0.0
|
||||
|
||||
return totals
|
||||
|
||||
|
||||
|
||||
|
||||
def _extract_phase_features(self, llm_responses: List[dict], game_data: dict) -> List[dict]:
|
||||
"""Extract phase-level features for all powers, phases, and response types."""
|
||||
phase_features = []
|
||||
|
|
@ -294,6 +397,10 @@ class StatisticalGameAnalyzer:
|
|||
# === FAILURE ANALYSIS (HARD MODE) ===
|
||||
failure_metrics = self._analyze_failures(power, phase, response_type, llm_responses)
|
||||
features.update(failure_metrics)
|
||||
|
||||
# === ORDER-RESULT METRICS ===
|
||||
order_result_features = self._extract_order_results_features(power, phase_data)
|
||||
features.update(order_result_features)
|
||||
|
||||
|
||||
# Add response-type specific features
|
||||
|
|
@ -794,7 +901,10 @@ class StatisticalGameAnalyzer:
|
|||
if total_calls > 0:
|
||||
features['overall_failure_rate_percentage'] = (total_failures / total_calls) * 100.0
|
||||
features['overall_success_rate_percentage'] = (total_successes / total_calls) * 100.0
|
||||
|
||||
|
||||
# === ORDER TOTALS (whole game) ===
|
||||
order_totals = self._aggregate_order_results(power, game_data)
|
||||
features.update(order_totals)
|
||||
|
||||
# Helper methods
|
||||
|
||||
|
|
@ -1067,6 +1177,15 @@ class StatisticalGameAnalyzer:
|
|||
'military_units_gained_vs_prev_phase',
|
||||
'relationships'
|
||||
]
|
||||
|
||||
# ensure order columns
|
||||
for ot in self.ORDER_TYPES:
|
||||
plural = f"{ot}s" if not ot.endswith("s") else ot
|
||||
for suffix in ("total", "success", "bounce", "void", "invalid", "success_rate"):
|
||||
col = f"orders_{plural}_{suffix}"
|
||||
if col not in fieldnames:
|
||||
fieldnames.append(col)
|
||||
|
||||
|
||||
# Ensure all actual fields are included (in case we missed any)
|
||||
actual_fields = set()
|
||||
|
|
@ -1140,6 +1259,17 @@ class StatisticalGameAnalyzer:
|
|||
# === Diplobench style single scalar game score ===
|
||||
'game_score',
|
||||
]
|
||||
|
||||
# ensure order-total columns
|
||||
for ot in self.ORDER_TYPES:
|
||||
plural = f"{ot}s" if not ot.endswith("s") else ot
|
||||
base = f"orders_{plural}_total"
|
||||
for suffix in ("total", "success", "bounce", "void", "invalid", "success_rate"):
|
||||
col = f"orders_{plural}_{suffix}"
|
||||
if col not in fieldnames:
|
||||
fieldnames.append(col)
|
||||
|
||||
|
||||
|
||||
# Ensure all actual fields are included
|
||||
actual_fields = set()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue