mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-19 12:58:09 +00:00
Adding PowerEnum as a seperate model
The PowerEnum correctly handles some misspellings. It can be easily expanded to handle more within the _POWER_ALIASES dict.
This commit is contained in:
parent
a241e34496
commit
540c2003e8
8 changed files with 1236 additions and 1053 deletions
|
|
@ -5,10 +5,13 @@ import json
|
||||||
import re
|
import re
|
||||||
import json_repair
|
import json_repair
|
||||||
import json5 # More forgiving JSON parser
|
import json5 # More forgiving JSON parser
|
||||||
import ast # For literal_eval
|
import ast
|
||||||
|
|
||||||
|
from ..config import config
|
||||||
|
|
||||||
# Assuming BaseModelClient is importable from clients.py in the same directory
|
# Assuming BaseModelClient is importable from clients.py in the same directory
|
||||||
from .clients import BaseModelClient, load_model_client
|
from .clients import BaseModelClient, load_model_client
|
||||||
|
|
||||||
# Import load_prompt and the new logging wrapper from utils
|
# Import load_prompt and the new logging wrapper from utils
|
||||||
from .utils import load_prompt, run_llm_and_log, log_llm_response, get_prompt_path
|
from .utils import load_prompt, run_llm_and_log, log_llm_response, get_prompt_path
|
||||||
from .prompt_constructor import build_context_prompt # Added import
|
from .prompt_constructor import build_context_prompt # Added import
|
||||||
|
|
@ -22,6 +25,7 @@ logger = logging.getLogger(__name__)
|
||||||
ALL_POWERS = frozenset({"AUSTRIA", "ENGLAND", "FRANCE", "GERMANY", "ITALY", "RUSSIA", "TURKEY"})
|
ALL_POWERS = frozenset({"AUSTRIA", "ENGLAND", "FRANCE", "GERMANY", "ITALY", "RUSSIA", "TURKEY"})
|
||||||
ALLOWED_RELATIONSHIPS = ["Enemy", "Unfriendly", "Neutral", "Friendly", "Ally"]
|
ALLOWED_RELATIONSHIPS = ["Enemy", "Unfriendly", "Neutral", "Friendly", "Ally"]
|
||||||
|
|
||||||
|
|
||||||
# == New: Helper function to load prompt files reliably ==
|
# == New: Helper function to load prompt files reliably ==
|
||||||
def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> Optional[str]:
|
def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> Optional[str]:
|
||||||
"""Loads a prompt template from the prompts directory."""
|
"""Loads a prompt template from the prompts directory."""
|
||||||
|
|
@ -31,10 +35,10 @@ def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> Optio
|
||||||
else:
|
else:
|
||||||
# Construct path relative to this file's location
|
# Construct path relative to this file's location
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
default_prompts_dir = os.path.join(current_dir, 'prompts')
|
default_prompts_dir = os.path.join(current_dir, "prompts")
|
||||||
filepath = os.path.join(default_prompts_dir, filename)
|
filepath = os.path.join(default_prompts_dir, filename)
|
||||||
|
|
||||||
with open(filepath, 'r', encoding='utf-8') as f:
|
with open(filepath, "r", encoding="utf-8") as f:
|
||||||
return f.read()
|
return f.read()
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
logger.error(f"Prompt file not found: {filepath}")
|
logger.error(f"Prompt file not found: {filepath}")
|
||||||
|
|
@ -43,12 +47,14 @@ def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> Optio
|
||||||
logger.error(f"Error loading prompt file {filepath}: {e}")
|
logger.error(f"Error loading prompt file {filepath}: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class DiplomacyAgent:
|
class DiplomacyAgent:
|
||||||
"""
|
"""
|
||||||
Represents a stateful AI agent playing as a specific power in Diplomacy.
|
Represents a stateful AI agent playing as a specific power in Diplomacy.
|
||||||
It holds the agent's goals, relationships, and private journal,
|
It holds the agent's goals, relationships, and private journal,
|
||||||
and uses a BaseModelClient instance to interact with the LLM.
|
and uses a BaseModelClient instance to interact with the LLM.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
power_name: str,
|
power_name: str,
|
||||||
|
|
@ -104,12 +110,9 @@ class DiplomacyAgent:
|
||||||
system_prompt_content = load_prompt(power_prompt_path)
|
system_prompt_content = load_prompt(power_prompt_path)
|
||||||
|
|
||||||
if not system_prompt_content:
|
if not system_prompt_content:
|
||||||
logger.warning(
|
logger.warning(f"Power-specific prompt not found at {power_prompt_path}. Falling back to default.")
|
||||||
f"Power-specific prompt not found at {power_prompt_path}. Falling back to default."
|
|
||||||
)
|
|
||||||
system_prompt_content = load_prompt(default_prompt_path)
|
system_prompt_content = load_prompt(default_prompt_path)
|
||||||
|
|
||||||
|
|
||||||
if system_prompt_content: # Ensure we actually have content before setting
|
if system_prompt_content: # Ensure we actually have content before setting
|
||||||
self.client.set_system_prompt(system_prompt_content)
|
self.client.set_system_prompt(system_prompt_content)
|
||||||
else:
|
else:
|
||||||
|
|
@ -131,11 +134,16 @@ class DiplomacyAgent:
|
||||||
text = re.sub(r'\n\s+"(\w+)"\s*:', r'"\1":', text) # Remove newlines before keys
|
text = re.sub(r'\n\s+"(\w+)"\s*:', r'"\1":', text) # Remove newlines before keys
|
||||||
# Fix specific patterns that cause trouble
|
# Fix specific patterns that cause trouble
|
||||||
problematic_patterns = [
|
problematic_patterns = [
|
||||||
'negotiation_summary', 'relationship_updates', 'updated_relationships',
|
"negotiation_summary",
|
||||||
'order_summary', 'goals', 'relationships', 'intent'
|
"relationship_updates",
|
||||||
|
"updated_relationships",
|
||||||
|
"order_summary",
|
||||||
|
"goals",
|
||||||
|
"relationships",
|
||||||
|
"intent",
|
||||||
]
|
]
|
||||||
for pattern in problematic_patterns:
|
for pattern in problematic_patterns:
|
||||||
text = re.sub(fr'\n\s*"{pattern}"', f'"{pattern}"', text)
|
text = re.sub(rf'\n\s*"{pattern}"', f'"{pattern}"', text)
|
||||||
|
|
||||||
# Try different patterns to extract JSON
|
# Try different patterns to extract JSON
|
||||||
# Order matters - try most specific patterns first
|
# Order matters - try most specific patterns first
|
||||||
|
|
@ -169,7 +177,9 @@ class DiplomacyAgent:
|
||||||
logger.debug(f"[{self.power_name}] Successfully parsed JSON object with pattern {pattern_idx}, match {match_idx}")
|
logger.debug(f"[{self.power_name}] Successfully parsed JSON object with pattern {pattern_idx}, match {match_idx}")
|
||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
logger.warning(f"[{self.power_name}] Parsed JSON with pattern {pattern_idx}, match {match_idx}, but got type {type(result)} instead of dict. Content: {str(result)[:200]}")
|
logger.warning(
|
||||||
|
f"[{self.power_name}] Parsed JSON with pattern {pattern_idx}, match {match_idx}, but got type {type(result)} instead of dict. Content: {str(result)[:200]}"
|
||||||
|
)
|
||||||
except json.JSONDecodeError as e_initial:
|
except json.JSONDecodeError as e_initial:
|
||||||
logger.debug(f"[{self.power_name}] Standard JSON parse failed: {e_initial}")
|
logger.debug(f"[{self.power_name}] Standard JSON parse failed: {e_initial}")
|
||||||
|
|
||||||
|
|
@ -179,16 +189,20 @@ class DiplomacyAgent:
|
||||||
cleaned_match_candidate = json_text
|
cleaned_match_candidate = json_text
|
||||||
|
|
||||||
# Pattern 1: Removes 'Sentence.' when followed by ',', '}', or ']'
|
# Pattern 1: Removes 'Sentence.' when followed by ',', '}', or ']'
|
||||||
cleaned_match_candidate = re.sub(r'\s*([A-Z][\w\s,]*?\.(?:\s+[A-Z][\w\s,]*?\.)*)\s*(?=[,\}\]])', '', cleaned_match_candidate)
|
cleaned_match_candidate = re.sub(
|
||||||
|
r"\s*([A-Z][\w\s,]*?\.(?:\s+[A-Z][\w\s,]*?\.)*)\s*(?=[,\}\]])", "", cleaned_match_candidate
|
||||||
|
)
|
||||||
|
|
||||||
# Pattern 2: Removes 'Sentence.' when it's at the very end, before the final '}' of the current scope
|
# Pattern 2: Removes 'Sentence.' when it's at the very end, before the final '}' of the current scope
|
||||||
cleaned_match_candidate = re.sub(r'\s*([A-Z][\w\s,]*?\.(?:\s+[A-Z][\w\s,]*?\.)*)\s*(?=\s*\}\s*$)', '', cleaned_match_candidate)
|
cleaned_match_candidate = re.sub(
|
||||||
|
r"\s*([A-Z][\w\s,]*?\.(?:\s+[A-Z][\w\s,]*?\.)*)\s*(?=\s*\}\s*$)", "", cleaned_match_candidate
|
||||||
|
)
|
||||||
|
|
||||||
# Pattern 3: Fix for newlines and spaces before JSON keys (common problem with LLMs)
|
# Pattern 3: Fix for newlines and spaces before JSON keys (common problem with LLMs)
|
||||||
cleaned_match_candidate = re.sub(r'\n\s+"(\w+)"\s*:', r'"\1":', cleaned_match_candidate)
|
cleaned_match_candidate = re.sub(r'\n\s+"(\w+)"\s*:', r'"\1":', cleaned_match_candidate)
|
||||||
|
|
||||||
# Pattern 4: Fix trailing commas in JSON objects
|
# Pattern 4: Fix trailing commas in JSON objects
|
||||||
cleaned_match_candidate = re.sub(r',\s*}', '}', cleaned_match_candidate)
|
cleaned_match_candidate = re.sub(r",\s*}", "}", cleaned_match_candidate)
|
||||||
|
|
||||||
# Pattern 5: Handle specific known problematic patterns
|
# Pattern 5: Handle specific known problematic patterns
|
||||||
for pattern in problematic_patterns:
|
for pattern in problematic_patterns:
|
||||||
|
|
@ -211,7 +225,9 @@ class DiplomacyAgent:
|
||||||
logger.debug(f"[{self.power_name}] Successfully parsed JSON object with json5")
|
logger.debug(f"[{self.power_name}] Successfully parsed JSON object with json5")
|
||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
logger.warning(f"[{self.power_name}] Parsed with json5, but got type {type(result)} instead of dict. Content: {str(result)[:200]}")
|
logger.warning(
|
||||||
|
f"[{self.power_name}] Parsed with json5, but got type {type(result)} instead of dict. Content: {str(result)[:200]}"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"[{self.power_name}] json5 parse failed: {e}")
|
logger.debug(f"[{self.power_name}] json5 parse failed: {e}")
|
||||||
|
|
||||||
|
|
@ -222,7 +238,9 @@ class DiplomacyAgent:
|
||||||
logger.debug(f"[{self.power_name}] Successfully parsed JSON object with json-repair")
|
logger.debug(f"[{self.power_name}] Successfully parsed JSON object with json-repair")
|
||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
logger.warning(f"[{self.power_name}] Parsed with json-repair, but got type {type(result)} instead of dict. Content: {str(result)[:200]}")
|
logger.warning(
|
||||||
|
f"[{self.power_name}] Parsed with json-repair, but got type {type(result)} instead of dict. Content: {str(result)[:200]}"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"[{self.power_name}] json-repair failed: {e}")
|
logger.debug(f"[{self.power_name}] json-repair failed: {e}")
|
||||||
|
|
||||||
|
|
@ -235,8 +253,8 @@ class DiplomacyAgent:
|
||||||
# Regex to find **key:** value, where value can be multi-line until next **key:** or end of string
|
# Regex to find **key:** value, where value can be multi-line until next **key:** or end of string
|
||||||
md_pattern = r"\*\*(?P<key>[^:]+):\*\*\s*(?P<value>[\s\S]*?)(?=(?:\n\s*\*\*|$))"
|
md_pattern = r"\*\*(?P<key>[^:]+):\*\*\s*(?P<value>[\s\S]*?)(?=(?:\n\s*\*\*|$))"
|
||||||
for match in re.finditer(md_pattern, text, re.DOTALL):
|
for match in re.finditer(md_pattern, text, re.DOTALL):
|
||||||
key_name = match.group('key').strip()
|
key_name = match.group("key").strip()
|
||||||
value_str = match.group('value').strip()
|
value_str = match.group("value").strip()
|
||||||
try:
|
try:
|
||||||
# Attempt to evaluate the value string as a Python literal
|
# Attempt to evaluate the value string as a Python literal
|
||||||
# This handles lists, strings, numbers, booleans, None
|
# This handles lists, strings, numbers, booleans, None
|
||||||
|
|
@ -247,7 +265,9 @@ class DiplomacyAgent:
|
||||||
# Or it could be genuinely malformed. We'll take it as a string if it's not empty.
|
# Or it could be genuinely malformed. We'll take it as a string if it's not empty.
|
||||||
if value_str: # Only add if it's a non-empty string
|
if value_str: # Only add if it's a non-empty string
|
||||||
markdown_data[key_name] = value_str # Store as string
|
markdown_data[key_name] = value_str # Store as string
|
||||||
logger.debug(f"[{self.power_name}] ast.literal_eval failed for key '{key_name}', value '{value_str[:50]}...': {e_ast}. Storing as string if non-empty.")
|
logger.debug(
|
||||||
|
f"[{self.power_name}] ast.literal_eval failed for key '{key_name}', value '{value_str[:50]}...': {e_ast}. Storing as string if non-empty."
|
||||||
|
)
|
||||||
|
|
||||||
if markdown_data: # If we successfully extracted any key-value pairs this way
|
if markdown_data: # If we successfully extracted any key-value pairs this way
|
||||||
# Check if essential keys are present, if needed, or just return if any data found
|
# Check if essential keys are present, if needed, or just return if any data found
|
||||||
|
|
@ -262,17 +282,13 @@ class DiplomacyAgent:
|
||||||
# Fallback: Try to find ANY JSON-like structure
|
# Fallback: Try to find ANY JSON-like structure
|
||||||
try:
|
try:
|
||||||
# Find the first { and last }
|
# Find the first { and last }
|
||||||
start = text.find('{')
|
start = text.find("{")
|
||||||
end = text.rfind('}') + 1 # Include the closing brace
|
end = text.rfind("}") + 1 # Include the closing brace
|
||||||
if start != -1 and end > start:
|
if start != -1 and end > start:
|
||||||
potential_json = text[start:end]
|
potential_json = text[start:end]
|
||||||
|
|
||||||
# Try all parsers on this extracted text
|
# Try all parsers on this extracted text
|
||||||
for parser_name, parser_func in [
|
for parser_name, parser_func in [("json", json.loads), ("json5", json5.loads), ("json_repair", json_repair.loads)]:
|
||||||
("json", json.loads),
|
|
||||||
("json5", json5.loads),
|
|
||||||
("json_repair", json_repair.loads)
|
|
||||||
]:
|
|
||||||
try:
|
try:
|
||||||
cleaned = self._clean_json_text(potential_json) if parser_name == "json" else potential_json
|
cleaned = self._clean_json_text(potential_json) if parser_name == "json" else potential_json
|
||||||
result = parser_func(cleaned)
|
result = parser_func(cleaned)
|
||||||
|
|
@ -280,24 +296,28 @@ class DiplomacyAgent:
|
||||||
logger.debug(f"[{self.power_name}] Fallback parse succeeded with {parser_name}, got dict.")
|
logger.debug(f"[{self.power_name}] Fallback parse succeeded with {parser_name}, got dict.")
|
||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
logger.warning(f"[{self.power_name}] Fallback parse with {parser_name} succeeded, but got type {type(result)} instead of dict. Content: {str(result)[:200]}")
|
logger.warning(
|
||||||
|
f"[{self.power_name}] Fallback parse with {parser_name} succeeded, but got type {type(result)} instead of dict. Content: {str(result)[:200]}"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"[{self.power_name}] Fallback {parser_name} failed: {e}")
|
logger.debug(f"[{self.power_name}] Fallback {parser_name} failed: {e}")
|
||||||
|
|
||||||
# If standard parsers failed, try aggressive cleaning
|
# If standard parsers failed, try aggressive cleaning
|
||||||
try:
|
try:
|
||||||
# Remove common non-JSON text that LLMs might add
|
# Remove common non-JSON text that LLMs might add
|
||||||
cleaned_text = re.sub(r'[^{}[\]"\',:.\d\w\s_-]', '', potential_json)
|
cleaned_text = re.sub(r'[^{}[\]"\',:.\d\w\s_-]', "", potential_json)
|
||||||
# Replace single quotes with double quotes (common LLM error)
|
# Replace single quotes with double quotes (common LLM error)
|
||||||
text_fixed = re.sub(r"'([^']*)':", r'"\1":', cleaned_text)
|
text_fixed = re.sub(r"'([^']*)':", r'"\1":', cleaned_text)
|
||||||
text_fixed = re.sub(r': *\'([^\']*)\'', r': "\1"', text_fixed)
|
text_fixed = re.sub(r": *\'([^\']*)\'", r': "\1"', text_fixed)
|
||||||
|
|
||||||
result = json.loads(text_fixed)
|
result = json.loads(text_fixed)
|
||||||
if isinstance(result, dict):
|
if isinstance(result, dict):
|
||||||
logger.debug(f"[{self.power_name}] Aggressive cleaning worked, got dict.")
|
logger.debug(f"[{self.power_name}] Aggressive cleaning worked, got dict.")
|
||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
logger.warning(f"[{self.power_name}] Aggressive cleaning worked, but got type {type(result)} instead of dict. Content: {str(result)[:200]}")
|
logger.warning(
|
||||||
|
f"[{self.power_name}] Aggressive cleaning worked, but got type {type(result)} instead of dict. Content: {str(result)[:200]}"
|
||||||
|
)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
@ -311,7 +331,9 @@ class DiplomacyAgent:
|
||||||
logger.warning(f"[{self.power_name}] Last resort json-repair succeeded, got dict.")
|
logger.warning(f"[{self.power_name}] Last resort json-repair succeeded, got dict.")
|
||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
logger.warning(f"[{self.power_name}] Last resort json-repair succeeded, but got type {type(result)} instead of dict. Content: {str(result)[:200]}")
|
logger.warning(
|
||||||
|
f"[{self.power_name}] Last resort json-repair succeeded, but got type {type(result)} instead of dict. Content: {str(result)[:200]}"
|
||||||
|
)
|
||||||
# If even the last resort doesn't give a dict, return empty dict
|
# If even the last resort doesn't give a dict, return empty dict
|
||||||
return {}
|
return {}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -324,8 +346,8 @@ class DiplomacyAgent:
|
||||||
return text
|
return text
|
||||||
|
|
||||||
# Remove trailing commas
|
# Remove trailing commas
|
||||||
text = re.sub(r',\s*}', '}', text)
|
text = re.sub(r",\s*}", "}", text)
|
||||||
text = re.sub(r',\s*]', ']', text)
|
text = re.sub(r",\s*]", "]", text)
|
||||||
|
|
||||||
# Fix newlines before JSON keys
|
# Fix newlines before JSON keys
|
||||||
text = re.sub(r'\n\s+"(\w+)"\s*:', r'"\1":', text)
|
text = re.sub(r'\n\s+"(\w+)"\s*:', r'"\1":', text)
|
||||||
|
|
@ -334,19 +356,18 @@ class DiplomacyAgent:
|
||||||
text = re.sub(r"'(\w+)'\s*:", r'"\1":', text)
|
text = re.sub(r"'(\w+)'\s*:", r'"\1":', text)
|
||||||
|
|
||||||
# Remove comments (if any)
|
# Remove comments (if any)
|
||||||
text = re.sub(r'//.*$', '', text, flags=re.MULTILINE)
|
text = re.sub(r"//.*$", "", text, flags=re.MULTILINE)
|
||||||
text = re.sub(r'/\*.*?\*/', '', text, flags=re.DOTALL)
|
text = re.sub(r"/\*.*?\*/", "", text, flags=re.DOTALL)
|
||||||
|
|
||||||
# Fix unescaped quotes in values (basic attempt)
|
# Fix unescaped quotes in values (basic attempt)
|
||||||
# This is risky but sometimes helps with simple cases
|
# This is risky but sometimes helps with simple cases
|
||||||
text = re.sub(r':\s*"([^"]*)"([^",}\]]+)"', r': "\1\2"', text)
|
text = re.sub(r':\s*"([^"]*)"([^",}\]]+)"', r': "\1\2"', text)
|
||||||
|
|
||||||
# Remove any BOM or zero-width spaces
|
# Remove any BOM or zero-width spaces
|
||||||
text = text.replace('\ufeff', '').replace('\u200b', '')
|
text = text.replace("\ufeff", "").replace("\u200b", "")
|
||||||
|
|
||||||
return text.strip()
|
return text.strip()
|
||||||
|
|
||||||
|
|
||||||
def add_journal_entry(self, entry: str):
|
def add_journal_entry(self, entry: str):
|
||||||
"""Adds a formatted entry string to the agent's private journal."""
|
"""Adds a formatted entry string to the agent's private journal."""
|
||||||
# Ensure entry is a string
|
# Ensure entry is a string
|
||||||
|
|
@ -366,7 +387,9 @@ class DiplomacyAgent:
|
||||||
# Also add to the context diary, which will be periodically rebuilt
|
# Also add to the context diary, which will be periodically rebuilt
|
||||||
self.private_diary.append(formatted_entry)
|
self.private_diary.append(formatted_entry)
|
||||||
|
|
||||||
logger.info(f"[{self.power_name}] DIARY ENTRY ADDED for {phase}. Total full entries: {len(self.full_private_diary)}. New entry: {entry[:100]}...")
|
logger.info(
|
||||||
|
f"[{self.power_name}] DIARY ENTRY ADDED for {phase}. Total full entries: {len(self.full_private_diary)}. New entry: {entry[:100]}..."
|
||||||
|
)
|
||||||
|
|
||||||
def format_private_diary_for_prompt(self) -> str:
|
def format_private_diary_for_prompt(self) -> str:
|
||||||
"""
|
"""
|
||||||
|
|
@ -407,14 +430,16 @@ class DiplomacyAgent:
|
||||||
if not formatted_diary:
|
if not formatted_diary:
|
||||||
return "(No diary entries to show)"
|
return "(No diary entries to show)"
|
||||||
|
|
||||||
logger.info(f"[{self.power_name}] Formatted diary with {1 if consolidated_entry else 0} consolidated and {len(recent_entries)} recent entries. Preview: {formatted_diary[:250]}...")
|
logger.info(
|
||||||
|
f"[{self.power_name}] Formatted diary with {1 if consolidated_entry else 0} consolidated and {len(recent_entries)} recent entries. Preview: {formatted_diary[:250]}..."
|
||||||
|
)
|
||||||
return formatted_diary
|
return formatted_diary
|
||||||
|
|
||||||
# The consolidate_entire_diary method has been moved to ai_diplomacy/diary_logic.py
|
# The consolidate_entire_diary method has been moved to ai_diplomacy/diary_logic.py
|
||||||
# to improve modularity and avoid circular dependencies.
|
# to improve modularity and avoid circular dependencies.
|
||||||
# It is now called as `run_diary_consolidation(agent, game, ...)` from the main game loop.
|
# It is now called as `run_diary_consolidation(agent, game, ...)` from the main game loop.
|
||||||
|
|
||||||
async def generate_negotiation_diary_entry(self, game: 'Game', game_history: GameHistory, log_file_path: str):
|
async def generate_negotiation_diary_entry(self, game: "Game", game_history: GameHistory, log_file_path: str):
|
||||||
"""
|
"""
|
||||||
Generates a diary entry summarizing negotiations and updates relationships.
|
Generates a diary entry summarizing negotiations and updates relationships.
|
||||||
This method now includes comprehensive LLM interaction logging.
|
This method now includes comprehensive LLM interaction logging.
|
||||||
|
|
@ -427,7 +452,7 @@ class DiplomacyAgent:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Load the prompt template file
|
# Load the prompt template file
|
||||||
prompt_template_content = _load_prompt_file(get_prompt_path('negotiation_diary_prompt.txt'), prompts_dir=self.prompts_dir)
|
prompt_template_content = _load_prompt_file(get_prompt_path("negotiation_diary_prompt.txt"), prompts_dir=self.prompts_dir)
|
||||||
if not prompt_template_content:
|
if not prompt_template_content:
|
||||||
logger.error(f"[{self.power_name}] Could not load {get_prompt_path('negotiation_diary_prompt.txt')}. Skipping diary entry.")
|
logger.error(f"[{self.power_name}] Could not load {get_prompt_path('negotiation_diary_prompt.txt')}. Skipping diary entry.")
|
||||||
success_status = "Failure: Prompt file not loaded"
|
success_status = "Failure: Prompt file not loaded"
|
||||||
|
|
@ -437,12 +462,11 @@ class DiplomacyAgent:
|
||||||
board_state_dict = game.get_state()
|
board_state_dict = game.get_state()
|
||||||
board_state_str = f"Units: {board_state_dict.get('units', {})}, Centers: {board_state_dict.get('centers', {})}"
|
board_state_str = f"Units: {board_state_dict.get('units', {})}, Centers: {board_state_dict.get('centers', {})}"
|
||||||
|
|
||||||
messages_this_round = game_history.get_messages_this_round(
|
messages_this_round = game_history.get_messages_this_round(power_name=self.power_name, current_phase_name=game.current_short_phase)
|
||||||
power_name=self.power_name,
|
|
||||||
current_phase_name=game.current_short_phase
|
|
||||||
)
|
|
||||||
if not messages_this_round.strip() or messages_this_round.startswith("\n(No messages"):
|
if not messages_this_round.strip() or messages_this_round.startswith("\n(No messages"):
|
||||||
messages_this_round = "(No messages involving your power this round that require deep reflection for diary. Focus on overall situation.)"
|
messages_this_round = (
|
||||||
|
"(No messages involving your power this round that require deep reflection for diary. Focus on overall situation.)"
|
||||||
|
)
|
||||||
|
|
||||||
current_relationships_str = json.dumps(self.relationships)
|
current_relationships_str = json.dumps(self.relationships)
|
||||||
current_goals_str = json.dumps(self.goals)
|
current_goals_str = json.dumps(self.goals)
|
||||||
|
|
@ -462,28 +486,31 @@ class DiplomacyAgent:
|
||||||
|
|
||||||
# Do aggressive preprocessing of the template to fix the problematic patterns
|
# Do aggressive preprocessing of the template to fix the problematic patterns
|
||||||
# This includes removing any newlines or whitespace before JSON keys that cause issues
|
# This includes removing any newlines or whitespace before JSON keys that cause issues
|
||||||
for pattern in ['negotiation_summary', 'updated_relationships', 'relationship_updates', 'intent']:
|
for pattern in ["negotiation_summary", "updated_relationships", "relationship_updates", "intent"]:
|
||||||
# Fix the "\n "key"" pattern that breaks .format()
|
# Fix the "\n "key"" pattern that breaks .format()
|
||||||
prompt_template_content = re.sub(
|
prompt_template_content = re.sub(rf'\n\s*"{pattern}"', f'"{pattern}"', prompt_template_content)
|
||||||
fr'\n\s*"{pattern}"',
|
|
||||||
f'"{pattern}"',
|
|
||||||
prompt_template_content
|
|
||||||
)
|
|
||||||
|
|
||||||
# Escape all curly braces in JSON examples to prevent format() from interpreting them
|
# Escape all curly braces in JSON examples to prevent format() from interpreting them
|
||||||
# First, temporarily replace the actual template variables
|
# First, temporarily replace the actual template variables
|
||||||
temp_vars = ['power_name', 'current_phase', 'messages_this_round', 'agent_goals',
|
temp_vars = [
|
||||||
'agent_relationships', 'board_state_str', 'ignored_messages_context']
|
"power_name",
|
||||||
|
"current_phase",
|
||||||
|
"messages_this_round",
|
||||||
|
"agent_goals",
|
||||||
|
"agent_relationships",
|
||||||
|
"board_state_str",
|
||||||
|
"ignored_messages_context",
|
||||||
|
]
|
||||||
for var in temp_vars:
|
for var in temp_vars:
|
||||||
prompt_template_content = prompt_template_content.replace(f'{{{var}}}', f'<<{var}>>')
|
prompt_template_content = prompt_template_content.replace(f"{{{var}}}", f"<<{var}>>")
|
||||||
|
|
||||||
# Now escape all remaining braces (which should be JSON)
|
# Now escape all remaining braces (which should be JSON)
|
||||||
prompt_template_content = prompt_template_content.replace('{', '{{')
|
prompt_template_content = prompt_template_content.replace("{", "{{")
|
||||||
prompt_template_content = prompt_template_content.replace('}', '}}')
|
prompt_template_content = prompt_template_content.replace("}", "}}")
|
||||||
|
|
||||||
# Restore the template variables
|
# Restore the template variables
|
||||||
for var in temp_vars:
|
for var in temp_vars:
|
||||||
prompt_template_content = prompt_template_content.replace(f'<<{var}>>', f'{{{var}}}')
|
prompt_template_content = prompt_template_content.replace(f"<<{var}>>", f"{{{var}}}")
|
||||||
|
|
||||||
# Create a dictionary with safe values for formatting
|
# Create a dictionary with safe values for formatting
|
||||||
format_vars = {
|
format_vars = {
|
||||||
|
|
@ -495,7 +522,7 @@ class DiplomacyAgent:
|
||||||
"agent_goals": current_goals_str,
|
"agent_goals": current_goals_str,
|
||||||
"allowed_relationships_str": ", ".join(ALLOWED_RELATIONSHIPS),
|
"allowed_relationships_str": ", ".join(ALLOWED_RELATIONSHIPS),
|
||||||
"private_diary_summary": formatted_diary,
|
"private_diary_summary": formatted_diary,
|
||||||
"ignored_messages_context": ignored_context
|
"ignored_messages_context": ignored_context,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Now try to use the template after preprocessing
|
# Now try to use the template after preprocessing
|
||||||
|
|
@ -511,7 +538,6 @@ class DiplomacyAgent:
|
||||||
|
|
||||||
logger.debug(f"[{self.power_name}] Negotiation diary prompt:\n{full_prompt[:500]}...")
|
logger.debug(f"[{self.power_name}] Negotiation diary prompt:\n{full_prompt[:500]}...")
|
||||||
|
|
||||||
|
|
||||||
logger.debug(f"[{self.power_name}] Negotiation diary prompt:\n{full_prompt[:500]}...")
|
logger.debug(f"[{self.power_name}] Negotiation diary prompt:\n{full_prompt[:500]}...")
|
||||||
|
|
||||||
raw_response = await run_llm_and_log(
|
raw_response = await run_llm_and_log(
|
||||||
|
|
@ -520,7 +546,7 @@ class DiplomacyAgent:
|
||||||
log_file_path=log_file_path, # Pass the main log file path
|
log_file_path=log_file_path, # Pass the main log file path
|
||||||
power_name=self.power_name,
|
power_name=self.power_name,
|
||||||
phase=game.current_short_phase,
|
phase=game.current_short_phase,
|
||||||
response_type='negotiation_diary_raw', # For run_llm_and_log context
|
response_type="negotiation_diary_raw", # For run_llm_and_log context
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"[{self.power_name}] Raw negotiation diary response: {raw_response[:300]}...")
|
logger.debug(f"[{self.power_name}] Raw negotiation diary response: {raw_response[:300]}...")
|
||||||
|
|
@ -528,14 +554,14 @@ class DiplomacyAgent:
|
||||||
parsed_data = None
|
parsed_data = None
|
||||||
try:
|
try:
|
||||||
# Conditionally format the response based on USE_UNFORMATTED_PROMPTS
|
# Conditionally format the response based on USE_UNFORMATTED_PROMPTS
|
||||||
if os.getenv("USE_UNFORMATTED_PROMPTS") == "1":
|
if config.USE_UNFORMATTED_PROMPTS:
|
||||||
# Format the natural language response into JSON
|
# Format the natural language response into JSON
|
||||||
formatted_response = await format_with_gemini_flash(
|
formatted_response = await format_with_gemini_flash(
|
||||||
raw_response,
|
raw_response,
|
||||||
FORMAT_NEGOTIATION_DIARY,
|
FORMAT_NEGOTIATION_DIARY,
|
||||||
power_name=self.power_name,
|
power_name=self.power_name,
|
||||||
phase=game.current_short_phase,
|
phase=game.current_short_phase,
|
||||||
log_file_path=log_file_path
|
log_file_path=log_file_path,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Use the raw response directly (already formatted)
|
# Use the raw response directly (already formatted)
|
||||||
|
|
@ -554,17 +580,17 @@ class DiplomacyAgent:
|
||||||
if parsed_data:
|
if parsed_data:
|
||||||
# Fix 1: Be more robust about extracting the negotiation_summary field
|
# Fix 1: Be more robust about extracting the negotiation_summary field
|
||||||
diary_text_candidate = None
|
diary_text_candidate = None
|
||||||
for key in ['negotiation_summary', 'summary', 'diary_entry']:
|
for key in ["negotiation_summary", "summary", "diary_entry"]:
|
||||||
if key in parsed_data and isinstance(parsed_data[key], str) and parsed_data[key].strip():
|
if key in parsed_data and isinstance(parsed_data[key], str) and parsed_data[key].strip():
|
||||||
diary_text_candidate = parsed_data[key].strip()
|
diary_text_candidate = parsed_data[key].strip()
|
||||||
logger.info(f"[{self.power_name}] Successfully extracted '{key}' for diary.")
|
logger.info(f"[{self.power_name}] Successfully extracted '{key}' for diary.")
|
||||||
break
|
break
|
||||||
|
|
||||||
if 'intent' in parsed_data:
|
if "intent" in parsed_data:
|
||||||
if diary_text_candidate == None:
|
if diary_text_candidate == None:
|
||||||
diary_text_candidate = parsed_data['intent']
|
diary_text_candidate = parsed_data["intent"]
|
||||||
else:
|
else:
|
||||||
diary_text_candidate += '\nIntent: ' + parsed_data['intent']
|
diary_text_candidate += "\nIntent: " + parsed_data["intent"]
|
||||||
|
|
||||||
if diary_text_candidate:
|
if diary_text_candidate:
|
||||||
diary_entry_text = diary_text_candidate
|
diary_entry_text = diary_text_candidate
|
||||||
|
|
@ -574,7 +600,7 @@ class DiplomacyAgent:
|
||||||
|
|
||||||
# Fix 2: Be more robust about extracting relationship updates
|
# Fix 2: Be more robust about extracting relationship updates
|
||||||
new_relationships = None
|
new_relationships = None
|
||||||
for key in ['relationship_updates', 'updated_relationships', 'relationships']:
|
for key in ["relationship_updates", "updated_relationships", "relationships"]:
|
||||||
if key in parsed_data and isinstance(parsed_data[key], dict):
|
if key in parsed_data and isinstance(parsed_data[key], dict):
|
||||||
new_relationships = parsed_data[key]
|
new_relationships = parsed_data[key]
|
||||||
logger.info(f"[{self.power_name}] Successfully extracted '{key}' for relationship updates.")
|
logger.info(f"[{self.power_name}] Successfully extracted '{key}' for relationship updates.")
|
||||||
|
|
@ -595,7 +621,9 @@ class DiplomacyAgent:
|
||||||
for p_changed, new_r_val in valid_new_rels.items():
|
for p_changed, new_r_val in valid_new_rels.items():
|
||||||
old_r_val = self.relationships.get(p_changed, "Unknown")
|
old_r_val = self.relationships.get(p_changed, "Unknown")
|
||||||
if old_r_val != new_r_val:
|
if old_r_val != new_r_val:
|
||||||
logger.info(f"[{self.power_name}] Relationship with {p_changed} changing from {old_r_val} to {new_r_val} based on diary.")
|
logger.info(
|
||||||
|
f"[{self.power_name}] Relationship with {p_changed} changing from {old_r_val} to {new_r_val} based on diary."
|
||||||
|
)
|
||||||
self.relationships.update(valid_new_rels)
|
self.relationships.update(valid_new_rels)
|
||||||
relationships_updated = True
|
relationships_updated = True
|
||||||
success_status = "Success: Applied diary data (relationships updated)"
|
success_status = "Success: Applied diary data (relationships updated)"
|
||||||
|
|
@ -631,17 +659,17 @@ class DiplomacyAgent:
|
||||||
response_type="negotiation_diary", # Specific type for CSV logging
|
response_type="negotiation_diary", # Specific type for CSV logging
|
||||||
raw_input_prompt=full_prompt,
|
raw_input_prompt=full_prompt,
|
||||||
raw_response=raw_response,
|
raw_response=raw_response,
|
||||||
success=success_status
|
success=success_status,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def generate_order_diary_entry(self, game: 'Game', orders: List[str], log_file_path: str):
|
async def generate_order_diary_entry(self, game: "Game", orders: List[str], log_file_path: str):
|
||||||
"""
|
"""
|
||||||
Generates a diary entry reflecting on the decided orders.
|
Generates a diary entry reflecting on the decided orders.
|
||||||
"""
|
"""
|
||||||
logger.info(f"[{self.power_name}] Generating order diary entry for {game.current_short_phase}...")
|
logger.info(f"[{self.power_name}] Generating order diary entry for {game.current_short_phase}...")
|
||||||
|
|
||||||
# Load the prompt template
|
# Load the prompt template
|
||||||
prompt_template = _load_prompt_file(get_prompt_path('order_diary_prompt.txt'), prompts_dir=self.prompts_dir)
|
prompt_template = _load_prompt_file(get_prompt_path("order_diary_prompt.txt"), prompts_dir=self.prompts_dir)
|
||||||
if not prompt_template:
|
if not prompt_template:
|
||||||
logger.error(f"[{self.power_name}] Could not load {get_prompt_path('order_diary_prompt.txt')}. Skipping diary entry.")
|
logger.error(f"[{self.power_name}] Could not load {get_prompt_path('order_diary_prompt.txt')}. Skipping diary entry.")
|
||||||
return
|
return
|
||||||
|
|
@ -656,23 +684,22 @@ class DiplomacyAgent:
|
||||||
|
|
||||||
# Do aggressive preprocessing on the template file
|
# Do aggressive preprocessing on the template file
|
||||||
# Fix any whitespace or formatting issues that could break .format()
|
# Fix any whitespace or formatting issues that could break .format()
|
||||||
for pattern in ['order_summary']:
|
for pattern in ["order_summary"]:
|
||||||
prompt_template = re.sub(fr'\n\s*"{pattern}"', f'"{pattern}"', prompt_template)
|
prompt_template = re.sub(rf'\n\s*"{pattern}"', f'"{pattern}"', prompt_template)
|
||||||
|
|
||||||
# Escape all curly braces in JSON examples to prevent format() from interpreting them
|
# Escape all curly braces in JSON examples to prevent format() from interpreting them
|
||||||
# First, temporarily replace the actual template variables
|
# First, temporarily replace the actual template variables
|
||||||
temp_vars = ['power_name', 'current_phase', 'orders_list_str', 'board_state_str',
|
temp_vars = ["power_name", "current_phase", "orders_list_str", "board_state_str", "agent_goals", "agent_relationships"]
|
||||||
'agent_goals', 'agent_relationships']
|
|
||||||
for var in temp_vars:
|
for var in temp_vars:
|
||||||
prompt_template = prompt_template.replace(f'{{{var}}}', f'<<{var}>>')
|
prompt_template = prompt_template.replace(f"{{{var}}}", f"<<{var}>>")
|
||||||
|
|
||||||
# Now escape all remaining braces (which should be JSON)
|
# Now escape all remaining braces (which should be JSON)
|
||||||
prompt_template = prompt_template.replace('{', '{{')
|
prompt_template = prompt_template.replace("{", "{{")
|
||||||
prompt_template = prompt_template.replace('}', '}}')
|
prompt_template = prompt_template.replace("}", "}}")
|
||||||
|
|
||||||
# Restore the template variables
|
# Restore the template variables
|
||||||
for var in temp_vars:
|
for var in temp_vars:
|
||||||
prompt_template = prompt_template.replace(f'<<{var}>>', f'{{{var}}}')
|
prompt_template = prompt_template.replace(f"<<{var}>>", f"{{{var}}}")
|
||||||
|
|
||||||
# Create a dictionary of variables for template formatting
|
# Create a dictionary of variables for template formatting
|
||||||
format_vars = {
|
format_vars = {
|
||||||
|
|
@ -681,7 +708,7 @@ class DiplomacyAgent:
|
||||||
"orders_list_str": orders_list_str,
|
"orders_list_str": orders_list_str,
|
||||||
"board_state_str": board_state_str,
|
"board_state_str": board_state_str,
|
||||||
"agent_goals": goals_str,
|
"agent_goals": goals_str,
|
||||||
"agent_relationships": relationships_str
|
"agent_relationships": relationships_str,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Try to use the template with proper formatting
|
# Try to use the template with proper formatting
|
||||||
|
|
@ -694,8 +721,6 @@ class DiplomacyAgent:
|
||||||
|
|
||||||
logger.debug(f"[{self.power_name}] Order diary prompt:\n{prompt[:300]}...")
|
logger.debug(f"[{self.power_name}] Order diary prompt:\n{prompt[:300]}...")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
response_data = None
|
response_data = None
|
||||||
raw_response = None # Initialize raw_response
|
raw_response = None # Initialize raw_response
|
||||||
try:
|
try:
|
||||||
|
|
@ -705,7 +730,7 @@ class DiplomacyAgent:
|
||||||
log_file_path=log_file_path,
|
log_file_path=log_file_path,
|
||||||
power_name=self.power_name,
|
power_name=self.power_name,
|
||||||
phase=game.current_short_phase,
|
phase=game.current_short_phase,
|
||||||
response_type='order_diary',
|
response_type="order_diary",
|
||||||
)
|
)
|
||||||
|
|
||||||
success_status = "FALSE"
|
success_status = "FALSE"
|
||||||
|
|
@ -718,11 +743,7 @@ class DiplomacyAgent:
|
||||||
if os.getenv("USE_UNFORMATTED_PROMPTS") == "1":
|
if os.getenv("USE_UNFORMATTED_PROMPTS") == "1":
|
||||||
# Format the natural language response into JSON
|
# Format the natural language response into JSON
|
||||||
formatted_response = await format_with_gemini_flash(
|
formatted_response = await format_with_gemini_flash(
|
||||||
raw_response,
|
raw_response, FORMAT_ORDER_DIARY, power_name=self.power_name, phase=game.current_short_phase, log_file_path=log_file_path
|
||||||
FORMAT_ORDER_DIARY,
|
|
||||||
power_name=self.power_name,
|
|
||||||
phase=game.current_short_phase,
|
|
||||||
log_file_path=log_file_path
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Use the raw response directly (already formatted)
|
# Use the raw response directly (already formatted)
|
||||||
|
|
@ -751,33 +772,35 @@ class DiplomacyAgent:
|
||||||
model_name=self.client.model_name,
|
model_name=self.client.model_name,
|
||||||
power_name=self.power_name,
|
power_name=self.power_name,
|
||||||
phase=game.current_short_phase,
|
phase=game.current_short_phase,
|
||||||
response_type='order_diary',
|
response_type="order_diary",
|
||||||
raw_input_prompt=prompt, # ENSURED
|
raw_input_prompt=prompt, # ENSURED
|
||||||
raw_response=raw_response if raw_response else "",
|
raw_response=raw_response if raw_response else "",
|
||||||
success=success_status
|
success=success_status,
|
||||||
)
|
)
|
||||||
|
|
||||||
if success_status == "TRUE" and actual_diary_text:
|
if success_status == "TRUE" and actual_diary_text:
|
||||||
self.add_diary_entry(actual_diary_text, game.current_short_phase)
|
self.add_diary_entry(actual_diary_text, game.current_short_phase)
|
||||||
logger.info(f"[{self.power_name}] Order diary entry generated and added.")
|
logger.info(f"[{self.power_name}] Order diary entry generated and added.")
|
||||||
else:
|
else:
|
||||||
fallback_diary = f"Submitted orders for {game.current_short_phase}: {', '.join(orders)}. (LLM failed to generate a specific diary entry)"
|
fallback_diary = (
|
||||||
|
f"Submitted orders for {game.current_short_phase}: {', '.join(orders)}. (LLM failed to generate a specific diary entry)"
|
||||||
|
)
|
||||||
self.add_diary_entry(fallback_diary, game.current_short_phase)
|
self.add_diary_entry(fallback_diary, game.current_short_phase)
|
||||||
logger.warning(f"[{self.power_name}] Failed to generate specific order diary entry. Added fallback.")
|
logger.warning(f"[{self.power_name}] Failed to generate specific order diary entry. Added fallback.")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Ensure prompt is defined or handled if it might not be (it should be in this flow)
|
# Ensure prompt is defined or handled if it might not be (it should be in this flow)
|
||||||
current_prompt = prompt if 'prompt' in locals() else "[prompt_unavailable_in_exception]"
|
current_prompt = prompt if "prompt" in locals() else "[prompt_unavailable_in_exception]"
|
||||||
current_raw_response = raw_response if 'raw_response' in locals() and raw_response is not None else f"Error: {e}"
|
current_raw_response = raw_response if "raw_response" in locals() and raw_response is not None else f"Error: {e}"
|
||||||
log_llm_response(
|
log_llm_response(
|
||||||
log_file_path=log_file_path,
|
log_file_path=log_file_path,
|
||||||
model_name=self.client.model_name if hasattr(self, 'client') else "UnknownModel",
|
model_name=self.client.model_name if hasattr(self, "client") else "UnknownModel",
|
||||||
power_name=self.power_name,
|
power_name=self.power_name,
|
||||||
phase=game.current_short_phase if 'game' in locals() and hasattr(game, 'current_short_phase') else "order_phase",
|
phase=game.current_short_phase if "game" in locals() and hasattr(game, "current_short_phase") else "order_phase",
|
||||||
response_type='order_diary_exception',
|
response_type="order_diary_exception",
|
||||||
raw_input_prompt=current_prompt, # ENSURED (using current_prompt for safety)
|
raw_input_prompt=current_prompt, # ENSURED (using current_prompt for safety)
|
||||||
raw_response=current_raw_response,
|
raw_response=current_raw_response,
|
||||||
success="FALSE"
|
success="FALSE",
|
||||||
)
|
)
|
||||||
fallback_diary = f"Submitted orders for {game.current_short_phase}: {', '.join(orders)}. (Critical error in diary generation process)"
|
fallback_diary = f"Submitted orders for {game.current_short_phase}: {', '.join(orders)}. (Critical error in diary generation process)"
|
||||||
self.add_diary_entry(fallback_diary, game.current_short_phase)
|
self.add_diary_entry(fallback_diary, game.current_short_phase)
|
||||||
|
|
@ -785,12 +808,7 @@ class DiplomacyAgent:
|
||||||
# Rest of the code remains the same
|
# Rest of the code remains the same
|
||||||
|
|
||||||
async def generate_phase_result_diary_entry(
|
async def generate_phase_result_diary_entry(
|
||||||
self,
|
self, game: "Game", game_history: "GameHistory", phase_summary: str, all_orders: Dict[str, List[str]], log_file_path: str
|
||||||
game: 'Game',
|
|
||||||
game_history: 'GameHistory',
|
|
||||||
phase_summary: str,
|
|
||||||
all_orders: Dict[str, List[str]],
|
|
||||||
log_file_path: str
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Generates a diary entry analyzing the actual phase results,
|
Generates a diary entry analyzing the actual phase results,
|
||||||
|
|
@ -799,7 +817,7 @@ class DiplomacyAgent:
|
||||||
logger.info(f"[{self.power_name}] Generating phase result diary entry for {game.current_short_phase}...")
|
logger.info(f"[{self.power_name}] Generating phase result diary entry for {game.current_short_phase}...")
|
||||||
|
|
||||||
# Load the template
|
# Load the template
|
||||||
prompt_template = _load_prompt_file('phase_result_diary_prompt.txt', prompts_dir=self.prompts_dir)
|
prompt_template = _load_prompt_file("phase_result_diary_prompt.txt", prompts_dir=self.prompts_dir)
|
||||||
if not prompt_template:
|
if not prompt_template:
|
||||||
logger.error(f"[{self.power_name}] Could not load phase_result_diary_prompt.txt. Skipping diary entry.")
|
logger.error(f"[{self.power_name}] Could not load phase_result_diary_prompt.txt. Skipping diary entry.")
|
||||||
return
|
return
|
||||||
|
|
@ -841,7 +859,7 @@ class DiplomacyAgent:
|
||||||
your_negotiations=your_negotiations,
|
your_negotiations=your_negotiations,
|
||||||
pre_phase_relationships=relationships_str,
|
pre_phase_relationships=relationships_str,
|
||||||
agent_goals=goals_str,
|
agent_goals=goals_str,
|
||||||
your_actual_orders=your_orders_str
|
your_actual_orders=your_orders_str,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"[{self.power_name}] Phase result diary prompt:\n{prompt[:500]}...")
|
logger.debug(f"[{self.power_name}] Phase result diary prompt:\n{prompt[:500]}...")
|
||||||
|
|
@ -856,7 +874,7 @@ class DiplomacyAgent:
|
||||||
log_file_path=log_file_path,
|
log_file_path=log_file_path,
|
||||||
power_name=self.power_name,
|
power_name=self.power_name,
|
||||||
phase=game.current_short_phase,
|
phase=game.current_short_phase,
|
||||||
response_type='phase_result_diary',
|
response_type="phase_result_diary",
|
||||||
)
|
)
|
||||||
|
|
||||||
if raw_response and raw_response.strip():
|
if raw_response and raw_response.strip():
|
||||||
|
|
@ -866,7 +884,9 @@ class DiplomacyAgent:
|
||||||
success_status = "TRUE"
|
success_status = "TRUE"
|
||||||
logger.info(f"[{self.power_name}] Phase result diary entry generated and added.")
|
logger.info(f"[{self.power_name}] Phase result diary entry generated and added.")
|
||||||
else:
|
else:
|
||||||
fallback_diary = f"Phase {game.current_short_phase} completed. Orders executed as: {your_orders_str}. (Failed to generate detailed analysis)"
|
fallback_diary = (
|
||||||
|
f"Phase {game.current_short_phase} completed. Orders executed as: {your_orders_str}. (Failed to generate detailed analysis)"
|
||||||
|
)
|
||||||
self.add_diary_entry(fallback_diary, game.current_short_phase)
|
self.add_diary_entry(fallback_diary, game.current_short_phase)
|
||||||
logger.warning(f"[{self.power_name}] Empty response from LLM. Added fallback phase result diary.")
|
logger.warning(f"[{self.power_name}] Empty response from LLM. Added fallback phase result diary.")
|
||||||
success_status = "FALSE"
|
success_status = "FALSE"
|
||||||
|
|
@ -882,17 +902,19 @@ class DiplomacyAgent:
|
||||||
model_name=self.client.model_name,
|
model_name=self.client.model_name,
|
||||||
power_name=self.power_name,
|
power_name=self.power_name,
|
||||||
phase=game.current_short_phase,
|
phase=game.current_short_phase,
|
||||||
response_type='phase_result_diary',
|
response_type="phase_result_diary",
|
||||||
raw_input_prompt=prompt,
|
raw_input_prompt=prompt,
|
||||||
raw_response=raw_response,
|
raw_response=raw_response,
|
||||||
success=success_status
|
success=success_status,
|
||||||
)
|
)
|
||||||
|
|
||||||
def log_state(self, prefix=""):
|
def log_state(self, prefix=""):
|
||||||
logger.debug(f"[{self.power_name}] {prefix} State: Goals={self.goals}, Relationships={self.relationships}")
|
logger.debug(f"[{self.power_name}] {prefix} State: Goals={self.goals}, Relationships={self.relationships}")
|
||||||
|
|
||||||
# Make this method async
|
# Make this method async
|
||||||
async def analyze_phase_and_update_state(self, game: 'Game', board_state: dict, phase_summary: str, game_history: 'GameHistory', log_file_path: str):
|
async def analyze_phase_and_update_state(
|
||||||
|
self, game: "Game", board_state: dict, phase_summary: str, game_history: "GameHistory", log_file_path: str
|
||||||
|
):
|
||||||
"""Analyzes the outcome of the last phase and updates goals/relationships using the LLM."""
|
"""Analyzes the outcome of the last phase and updates goals/relationships using the LLM."""
|
||||||
# Use self.power_name internally
|
# Use self.power_name internally
|
||||||
power_name = self.power_name
|
power_name = self.power_name
|
||||||
|
|
@ -902,7 +924,7 @@ class DiplomacyAgent:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 1. Construct the prompt using the unformatted state update prompt file
|
# 1. Construct the prompt using the unformatted state update prompt file
|
||||||
prompt_template = _load_prompt_file(get_prompt_path('state_update_prompt.txt'), prompts_dir=self.prompts_dir)
|
prompt_template = _load_prompt_file(get_prompt_path("state_update_prompt.txt"), prompts_dir=self.prompts_dir)
|
||||||
if not prompt_template:
|
if not prompt_template:
|
||||||
logger.error(f"[{power_name}] Could not load {get_prompt_path('state_update_prompt.txt')}. Skipping state update.")
|
logger.error(f"[{power_name}] Could not load {get_prompt_path('state_update_prompt.txt')}. Skipping state update.")
|
||||||
return
|
return
|
||||||
|
|
@ -944,10 +966,10 @@ class DiplomacyAgent:
|
||||||
|
|
||||||
# Create a readable board state string from the board_state dict
|
# Create a readable board state string from the board_state dict
|
||||||
board_state_str = f"Board State:\n"
|
board_state_str = f"Board State:\n"
|
||||||
for p_name, power_data in board_state.get('powers', {}).items():
|
for p_name, power_data in board_state.get("powers", {}).items():
|
||||||
# Get units and centers from the board state
|
# Get units and centers from the board state
|
||||||
units = power_data.get('units', [])
|
units = power_data.get("units", [])
|
||||||
centers = power_data.get('centers', [])
|
centers = power_data.get("centers", [])
|
||||||
board_state_str += f" {p_name}: Units={units}, Centers={centers}\n"
|
board_state_str += f" {p_name}: Units={units}, Centers={centers}\n"
|
||||||
|
|
||||||
# Extract year from the phase name (e.g., "S1901M" -> "1901")
|
# Extract year from the phase name (e.g., "S1901M" -> "1901")
|
||||||
|
|
@ -961,7 +983,7 @@ class DiplomacyAgent:
|
||||||
phase_summary=last_phase_summary, # Use provided phase_summary
|
phase_summary=last_phase_summary, # Use provided phase_summary
|
||||||
other_powers=str(other_powers), # Pass as string representation
|
other_powers=str(other_powers), # Pass as string representation
|
||||||
current_goals="\n".join([f"- {g}" for g in self.goals]) if self.goals else "None",
|
current_goals="\n".join([f"- {g}" for g in self.goals]) if self.goals else "None",
|
||||||
current_relationships=str(self.relationships) if self.relationships else "None"
|
current_relationships=str(self.relationships) if self.relationships else "None",
|
||||||
)
|
)
|
||||||
logger.debug(f"[{power_name}] State update prompt:\n{prompt}")
|
logger.debug(f"[{power_name}] State update prompt:\n{prompt}")
|
||||||
|
|
||||||
|
|
@ -973,11 +995,11 @@ class DiplomacyAgent:
|
||||||
log_file_path=log_file_path,
|
log_file_path=log_file_path,
|
||||||
power_name=power_name,
|
power_name=power_name,
|
||||||
phase=current_phase,
|
phase=current_phase,
|
||||||
response_type='state_update',
|
response_type="state_update",
|
||||||
)
|
)
|
||||||
logger.debug(f"[{power_name}] Raw LLM response for state update: {response}")
|
logger.debug(f"[{power_name}] Raw LLM response for state update: {response}")
|
||||||
|
|
||||||
log_entry_response_type = 'state_update' # Default for log_llm_response
|
log_entry_response_type = "state_update" # Default for log_llm_response
|
||||||
log_entry_success = "FALSE" # Default
|
log_entry_success = "FALSE" # Default
|
||||||
update_data = None # Initialize
|
update_data = None # Initialize
|
||||||
|
|
||||||
|
|
@ -987,11 +1009,7 @@ class DiplomacyAgent:
|
||||||
if os.getenv("USE_UNFORMATTED_PROMPTS") == "1":
|
if os.getenv("USE_UNFORMATTED_PROMPTS") == "1":
|
||||||
# Format the natural language response into JSON
|
# Format the natural language response into JSON
|
||||||
formatted_response = await format_with_gemini_flash(
|
formatted_response = await format_with_gemini_flash(
|
||||||
response,
|
response, FORMAT_STATE_UPDATE, power_name=power_name, phase=current_phase, log_file_path=log_file_path
|
||||||
FORMAT_STATE_UPDATE,
|
|
||||||
power_name=power_name,
|
|
||||||
phase=current_phase,
|
|
||||||
log_file_path=log_file_path
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Use the raw response directly (already formatted)
|
# Use the raw response directly (already formatted)
|
||||||
|
|
@ -1007,29 +1025,31 @@ class DiplomacyAgent:
|
||||||
# Check if essential data ('updated_goals' or 'goals') is present AND is a list (for goals)
|
# Check if essential data ('updated_goals' or 'goals') is present AND is a list (for goals)
|
||||||
# For relationships, check for 'updated_relationships' or 'relationships' AND is a dict.
|
# For relationships, check for 'updated_relationships' or 'relationships' AND is a dict.
|
||||||
# Consider it TRUE if at least one of the primary data structures (goals or relationships) is present and correctly typed.
|
# Consider it TRUE if at least one of the primary data structures (goals or relationships) is present and correctly typed.
|
||||||
goals_present_and_valid = isinstance(update_data.get('updated_goals'), list) or isinstance(update_data.get('goals'), list)
|
goals_present_and_valid = isinstance(update_data.get("updated_goals"), list) or isinstance(update_data.get("goals"), list)
|
||||||
rels_present_and_valid = isinstance(update_data.get('updated_relationships'), dict) or isinstance(update_data.get('relationships'), dict)
|
rels_present_and_valid = isinstance(update_data.get("updated_relationships"), dict) or isinstance(
|
||||||
|
update_data.get("relationships"), dict
|
||||||
|
)
|
||||||
|
|
||||||
if update_data and (goals_present_and_valid or rels_present_and_valid):
|
if update_data and (goals_present_and_valid or rels_present_and_valid):
|
||||||
log_entry_success = "TRUE"
|
log_entry_success = "TRUE"
|
||||||
elif update_data: # Parsed, but maybe not all essential data there or not correctly typed
|
elif update_data: # Parsed, but maybe not all essential data there or not correctly typed
|
||||||
log_entry_success = "PARTIAL"
|
log_entry_success = "PARTIAL"
|
||||||
log_entry_response_type = 'state_update_partial_data'
|
log_entry_response_type = "state_update_partial_data"
|
||||||
else: # Parsed to None or empty dict/list, or data not in expected format
|
else: # Parsed to None or empty dict/list, or data not in expected format
|
||||||
log_entry_success = "FALSE"
|
log_entry_success = "FALSE"
|
||||||
log_entry_response_type = 'state_update_parsing_empty_or_invalid_data'
|
log_entry_response_type = "state_update_parsing_empty_or_invalid_data"
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.error(f"[{power_name}] Failed to parse JSON response for state update: {e}. Raw response: {response}")
|
logger.error(f"[{power_name}] Failed to parse JSON response for state update: {e}. Raw response: {response}")
|
||||||
log_entry_response_type = 'state_update_json_error'
|
log_entry_response_type = "state_update_json_error"
|
||||||
# log_entry_success remains "FALSE"
|
# log_entry_success remains "FALSE"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[{power_name}] Unexpected error parsing state update: {e}")
|
logger.error(f"[{power_name}] Unexpected error parsing state update: {e}")
|
||||||
log_entry_response_type = 'state_update_unexpected_error'
|
log_entry_response_type = "state_update_unexpected_error"
|
||||||
update_data = {}
|
update_data = {}
|
||||||
# log_entry_success remains "FALSE"
|
# log_entry_success remains "FALSE"
|
||||||
else: # response was None or empty/whitespace
|
else: # response was None or empty/whitespace
|
||||||
logger.error(f"[{power_name}] No valid response (None or empty) received from LLM for state update.")
|
logger.error(f"[{power_name}] No valid response (None or empty) received from LLM for state update.")
|
||||||
log_entry_response_type = 'state_update_no_response'
|
log_entry_response_type = "state_update_no_response"
|
||||||
# log_entry_success remains "FALSE"
|
# log_entry_success remains "FALSE"
|
||||||
|
|
||||||
# Log the attempt and its outcome
|
# Log the attempt and its outcome
|
||||||
|
|
@ -1041,12 +1061,19 @@ class DiplomacyAgent:
|
||||||
response_type=log_entry_response_type,
|
response_type=log_entry_response_type,
|
||||||
raw_input_prompt=prompt, # ENSURED
|
raw_input_prompt=prompt, # ENSURED
|
||||||
raw_response=response if response is not None else "", # Handle if response is None
|
raw_response=response if response is not None else "", # Handle if response is None
|
||||||
success=log_entry_success
|
success=log_entry_success,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Fallback logic if update_data is still None or not usable
|
# Fallback logic if update_data is still None or not usable
|
||||||
if not update_data or not (isinstance(update_data.get('updated_goals'), list) or isinstance(update_data.get('goals'), list) or isinstance(update_data.get('updated_relationships'), dict) or isinstance(update_data.get('relationships'), dict)):
|
if not update_data or not (
|
||||||
logger.warning(f"[{power_name}] update_data is None or missing essential valid structures after LLM call. Using existing goals and relationships as fallback.")
|
isinstance(update_data.get("updated_goals"), list)
|
||||||
|
or isinstance(update_data.get("goals"), list)
|
||||||
|
or isinstance(update_data.get("updated_relationships"), dict)
|
||||||
|
or isinstance(update_data.get("relationships"), dict)
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
f"[{power_name}] update_data is None or missing essential valid structures after LLM call. Using existing goals and relationships as fallback."
|
||||||
|
)
|
||||||
update_data = {
|
update_data = {
|
||||||
"updated_goals": self.goals,
|
"updated_goals": self.goals,
|
||||||
"updated_relationships": self.relationships,
|
"updated_relationships": self.relationships,
|
||||||
|
|
@ -1055,15 +1082,15 @@ class DiplomacyAgent:
|
||||||
|
|
||||||
# Check for both possible key names (prompt uses "goals"/"relationships",
|
# Check for both possible key names (prompt uses "goals"/"relationships",
|
||||||
# but code was expecting "updated_goals"/"updated_relationships")
|
# but code was expecting "updated_goals"/"updated_relationships")
|
||||||
updated_goals = update_data.get('updated_goals')
|
updated_goals = update_data.get("updated_goals")
|
||||||
if updated_goals is None:
|
if updated_goals is None:
|
||||||
updated_goals = update_data.get('goals')
|
updated_goals = update_data.get("goals")
|
||||||
if updated_goals is not None:
|
if updated_goals is not None:
|
||||||
logger.debug(f"[{power_name}] Using 'goals' key instead of 'updated_goals'")
|
logger.debug(f"[{power_name}] Using 'goals' key instead of 'updated_goals'")
|
||||||
|
|
||||||
updated_relationships = update_data.get('updated_relationships')
|
updated_relationships = update_data.get("updated_relationships")
|
||||||
if updated_relationships is None:
|
if updated_relationships is None:
|
||||||
updated_relationships = update_data.get('relationships')
|
updated_relationships = update_data.get("relationships")
|
||||||
if updated_relationships is not None:
|
if updated_relationships is not None:
|
||||||
logger.debug(f"[{power_name}] Using 'relationships' key instead of 'updated_relationships'")
|
logger.debug(f"[{power_name}] Using 'relationships' key instead of 'updated_relationships'")
|
||||||
|
|
||||||
|
|
@ -1104,7 +1131,9 @@ class DiplomacyAgent:
|
||||||
# Update relationships if the dictionary is not empty after validation
|
# Update relationships if the dictionary is not empty after validation
|
||||||
if valid_new_relationships:
|
if valid_new_relationships:
|
||||||
self.relationships.update(valid_new_relationships)
|
self.relationships.update(valid_new_relationships)
|
||||||
self.add_journal_entry(f"[{game.current_short_phase}] Relationships updated based on {last_phase_name}: {valid_new_relationships}")
|
self.add_journal_entry(
|
||||||
|
f"[{game.current_short_phase}] Relationships updated based on {last_phase_name}: {valid_new_relationships}"
|
||||||
|
)
|
||||||
elif updated_relationships: # Log if the original dict wasn't empty but validation removed everything
|
elif updated_relationships: # Log if the original dict wasn't empty but validation removed everything
|
||||||
logger.warning(f"[{power_name}] Found relationships in LLM response but none were valid after normalization. Using defaults.")
|
logger.warning(f"[{power_name}] Found relationships in LLM response but none were valid after normalization. Using defaults.")
|
||||||
else: # Log if the original dict was empty
|
else: # Log if the original dict was empty
|
||||||
|
|
@ -1145,7 +1174,7 @@ class DiplomacyAgent:
|
||||||
# summary += f"\n Last Journal Entry: {self.private_journal[-1]}"
|
# summary += f"\n Last Journal Entry: {self.private_journal[-1]}"
|
||||||
return summary
|
return summary
|
||||||
|
|
||||||
def generate_plan(self, game: Game, board_state: dict, game_history: 'GameHistory') -> str:
|
def generate_plan(self, game: Game, board_state: dict, game_history: "GameHistory") -> str:
|
||||||
"""Generates a strategic plan using the client and logs it."""
|
"""Generates a strategic plan using the client and logs it."""
|
||||||
logger.info(f"Agent {self.power_name} generating strategic plan...")
|
logger.info(f"Agent {self.power_name} generating strategic plan...")
|
||||||
try:
|
try:
|
||||||
|
|
@ -1157,3 +1186,4 @@ class DiplomacyAgent:
|
||||||
logger.error(f"Agent {self.power_name} failed to generate plan: {e}")
|
logger.error(f"Agent {self.power_name} failed to generate plan: {e}")
|
||||||
self.add_journal_entry(f"Failed to generate plan for phase {game.current_phase} due to error: {e}")
|
self.add_journal_entry(f"Failed to generate plan for phase {game.current_phase} due to error: {e}")
|
||||||
return "Error: Failed to generate plan."
|
return "Error: Failed to generate plan."
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@ from collections import defaultdict, Counter
|
||||||
import re
|
import re
|
||||||
from typing import Dict, List, Tuple, Optional, Any
|
from typing import Dict, List, Tuple, Optional, Any
|
||||||
import statistics
|
import statistics
|
||||||
|
from ..models import PowerEnum
|
||||||
|
|
||||||
class StatisticalGameAnalyzer:
|
class StatisticalGameAnalyzer:
|
||||||
"""Production-ready analyzer for AI Diplomacy game statistics.
|
"""Production-ready analyzer for AI Diplomacy game statistics.
|
||||||
|
|
@ -47,7 +48,6 @@ class StatisticalGameAnalyzer:
|
||||||
'Ally': 2
|
'Ally': 2
|
||||||
}
|
}
|
||||||
|
|
||||||
DIPLOMACY_POWERS = ['AUSTRIA', 'ENGLAND', 'FRANCE', 'GERMANY', 'ITALY', 'RUSSIA', 'TURKEY']
|
|
||||||
|
|
||||||
# Complete list of response types found in actual data
|
# Complete list of response types found in actual data
|
||||||
RESPONSE_TYPES = [
|
RESPONSE_TYPES = [
|
||||||
|
|
@ -59,44 +59,8 @@ class StatisticalGameAnalyzer:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize analyzer with configuration constants."""
|
"""Initialize analyzer with configuration constants."""
|
||||||
self.relationship_values = self.RELATIONSHIP_VALUES
|
self.relationship_values = self.RELATIONSHIP_VALUES
|
||||||
self.powers = self.DIPLOMACY_POWERS
|
|
||||||
|
|
||||||
def _normalize_recipient_name(self, recipient: str) -> str:
|
|
||||||
"""Normalize recipient names to handle LLM typos and abbreviations."""
|
|
||||||
if not recipient:
|
|
||||||
return recipient
|
|
||||||
|
|
||||||
recipient = recipient.upper().strip()
|
|
||||||
|
|
||||||
# Handle common LLM typos and abbreviations found in data
|
|
||||||
name_mapping = {
|
|
||||||
'EGMANY': 'GERMANY',
|
|
||||||
'GERMAN': 'GERMANY',
|
|
||||||
'UK': 'ENGLAND',
|
|
||||||
'BRIT': 'ENGLAND',
|
|
||||||
'ENGLAND': 'ENGLAND', # Keep as-is
|
|
||||||
'FRANCE': 'FRANCE', # Keep as-is
|
|
||||||
'GERMANY': 'GERMANY', # Keep as-is
|
|
||||||
'ITALY': 'ITALY', # Keep as-is
|
|
||||||
'AUSTRIA': 'AUSTRIA', # Keep as-is
|
|
||||||
'RUSSIA': 'RUSSIA', # Keep as-is
|
|
||||||
'TURKEY': 'TURKEY', # Keep as-is
|
|
||||||
'Germany': 'GERMANY',
|
|
||||||
'England': 'ENGLAND',
|
|
||||||
'France': 'FRANCE',
|
|
||||||
'Italy': 'ITALY',
|
|
||||||
'Russia': 'RUSSIA',
|
|
||||||
'Austria': 'AUSTRIA',
|
|
||||||
'Turkey': 'TURKEY',
|
|
||||||
}
|
|
||||||
|
|
||||||
normalized = name_mapping.get(recipient, recipient)
|
|
||||||
|
|
||||||
# Validate it's a known power
|
|
||||||
if normalized not in self.DIPLOMACY_POWERS:
|
|
||||||
return None # Invalid recipient
|
|
||||||
|
|
||||||
return normalized
|
|
||||||
|
|
||||||
def analyze_folder(self, folder_path: str, output_dir: str = None) -> Tuple[str, str]:
|
def analyze_folder(self, folder_path: str, output_dir: str = None) -> Tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -272,11 +236,11 @@ class StatisticalGameAnalyzer:
|
||||||
if not phase_data:
|
if not phase_data:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for power in self.powers:
|
for power in PowerEnum:
|
||||||
for response_type in response_types:
|
for response_type in response_types:
|
||||||
# Extract features for this specific power/phase/response_type combination
|
# Extract features for this specific power/phase/response_type combination
|
||||||
features = self._extract_power_phase_response_features(
|
features = self._extract_power_phase_response_features(
|
||||||
power, phase_name, response_type, llm_responses, phase_data, game_data
|
power.value, phase_name, response_type, llm_responses, phase_data, game_data
|
||||||
)
|
)
|
||||||
if features:
|
if features:
|
||||||
phase_features.append(features)
|
phase_features.append(features)
|
||||||
|
|
@ -530,7 +494,7 @@ class StatisticalGameAnalyzer:
|
||||||
game_features = []
|
game_features = []
|
||||||
game_scores = self._compute_game_scores(game_data)
|
game_scores = self._compute_game_scores(game_data)
|
||||||
|
|
||||||
for power in self.powers:
|
for power in PowerEnum:
|
||||||
features = {
|
features = {
|
||||||
# === IDENTIFIERS ===
|
# === IDENTIFIERS ===
|
||||||
'game_id': game_data.get('id', 'unknown'),
|
'game_id': game_data.get('id', 'unknown'),
|
||||||
|
|
@ -718,7 +682,8 @@ class StatisticalGameAnalyzer:
|
||||||
|
|
||||||
# Categorize by relationship
|
# Categorize by relationship
|
||||||
recipient = msg.get('recipient_power')
|
recipient = msg.get('recipient_power')
|
||||||
normalized_recipient = self._normalize_recipient_name(recipient)
|
# This will coerce some known aliases to match the 7 acceptable names (see models.py)
|
||||||
|
normalized_recipient = PowerEnum(recipient)
|
||||||
|
|
||||||
# Skip self-messages and invalid recipients
|
# Skip self-messages and invalid recipients
|
||||||
if normalized_recipient and normalized_recipient != power and normalized_recipient in relationships:
|
if normalized_recipient and normalized_recipient != power and normalized_recipient in relationships:
|
||||||
|
|
@ -969,7 +934,7 @@ class StatisticalGameAnalyzer:
|
||||||
break
|
break
|
||||||
|
|
||||||
# elimination turn for every power
|
# elimination turn for every power
|
||||||
elim_turn: dict[str, int | None] = {p: None for p in self.DIPLOMACY_POWERS}
|
elim_turn: dict[str, int | None] = {p: None for p in [power.value for power in PowerEnum]}
|
||||||
for idx, ph in enumerate(phases):
|
for idx, ph in enumerate(phases):
|
||||||
yr = self._phase_year(phases, idx)
|
yr = self._phase_year(phases, idx)
|
||||||
if yr is None:
|
if yr is None:
|
||||||
|
|
|
||||||
48
bot_client/config.py
Normal file
48
bot_client/config.py
Normal file
|
|
@ -0,0 +1,48 @@
|
||||||
|
import datetime
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
from pathlib import Path
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
|
||||||
|
class Configuration(BaseSettings):
|
||||||
|
DEBUG: bool = False
|
||||||
|
log_file_path: Path | None = None
|
||||||
|
DEEPSEEK_API_KEY: str | None = None
|
||||||
|
OPENAI_API_KEY: str | None = None
|
||||||
|
ANTHROPIC_API_KEY: str | None = None
|
||||||
|
GEMINI_API_KEY: str | None = None
|
||||||
|
OPENROUTER_API_KEY: str | None = None
|
||||||
|
|
||||||
|
def __init__(self, power_name, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.log_file_path = Path(f"./logs/{datetime.datetime.now().strftime('%d-%m-%y_%H:%M')}/{power_name}.txt")
|
||||||
|
# Make the path absolute, gets rid of weirdness of calling this in different places
|
||||||
|
self.log_file_path = self.log_file_path.resolve()
|
||||||
|
self.log_file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self.log_file_path.touch(exist_ok=True)
|
||||||
|
|
||||||
|
self._validate_api_keys()
|
||||||
|
|
||||||
|
def _validate_api_keys(self):
|
||||||
|
"""Validate API keys at startup and issue warnings for missing keys"""
|
||||||
|
api_keys = [
|
||||||
|
"DEEPSEEK_API_KEY",
|
||||||
|
"OPENAI_API_KEY",
|
||||||
|
"ANTHROPIC_API_KEY",
|
||||||
|
"GEMINI_API_KEY",
|
||||||
|
"OPENROUTER_API_KEY",
|
||||||
|
]
|
||||||
|
|
||||||
|
for key in api_keys:
|
||||||
|
value = super().__getattribute__(key)
|
||||||
|
if not value or (isinstance(value, str) and len(value) == 0):
|
||||||
|
warnings.warn(f"API key '{key}' is not set or is empty", UserWarning)
|
||||||
|
|
||||||
|
def __getattribute__(self, name):
|
||||||
|
"""Override to check for empty API keys at access time"""
|
||||||
|
value = super().__getattribute__(name)
|
||||||
|
|
||||||
|
if name.endswith("_KEY") and (not value or (isinstance(value, str) and len(value) == 0)):
|
||||||
|
raise ValueError(f"API key '{name}' is not set or is empty. Please configure it before use.")
|
||||||
|
|
||||||
|
return value
|
||||||
60
config.py
Normal file
60
config.py
Normal file
|
|
@ -0,0 +1,60 @@
|
||||||
|
import datetime
|
||||||
|
from typing import Optional
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
from pathlib import Path
|
||||||
|
import warnings
|
||||||
|
from models import PowerEnum
|
||||||
|
|
||||||
|
|
||||||
|
class Configuration(BaseSettings):
|
||||||
|
DEBUG: bool = False
|
||||||
|
log_file_path: Path | None = None
|
||||||
|
USE_UNFORMATTED_PROMPTS: bool = False
|
||||||
|
|
||||||
|
# API Keys to be validated. Warns if they aren't present at startup, raises ValueError if you attempt to use them when they aren't present.
|
||||||
|
DEEPSEEK_API_KEY: str | None = None
|
||||||
|
OPENAI_API_KEY: str | None = None
|
||||||
|
ANTHROPIC_API_KEY: str | None = None
|
||||||
|
GEMINI_API_KEY: str | None = None
|
||||||
|
OPENROUTER_API_KEY: str | None = None
|
||||||
|
|
||||||
|
def __init__(self, power_name: Optional[PowerEnum] = None, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
# Add a '-POWER' to the end of the file name if it's for a specific power
|
||||||
|
log_power_path = "-" + power_name if power_name else None
|
||||||
|
self.log_file_path = Path(f"./logs/{datetime.datetime.now().strftime('%d-%m-%y_%H:%M')}/logs{log_power_path} .txt")
|
||||||
|
# Make the path absolute, gets rid of weirdness of calling this in different places
|
||||||
|
self.log_file_path = self.log_file_path.resolve()
|
||||||
|
self.log_file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self.log_file_path.touch(exist_ok=True)
|
||||||
|
|
||||||
|
self._validate_api_keys()
|
||||||
|
|
||||||
|
def _validate_api_keys(self):
|
||||||
|
"""Validate API keys at startup and issue warnings for missing keys"""
|
||||||
|
api_keys = [
|
||||||
|
"DEEPSEEK_API_KEY",
|
||||||
|
"OPENAI_API_KEY",
|
||||||
|
"ANTHROPIC_API_KEY",
|
||||||
|
"GEMINI_API_KEY",
|
||||||
|
"OPENROUTER_API_KEY",
|
||||||
|
]
|
||||||
|
|
||||||
|
for key in api_keys:
|
||||||
|
value = super().__getattribute__(key)
|
||||||
|
if not value or (isinstance(value, str) and len(value) == 0):
|
||||||
|
warnings.warn(f"API key '{key}' is not set or is empty", UserWarning)
|
||||||
|
|
||||||
|
def __getattribute__(self, name):
|
||||||
|
"""Override to check for empty API keys at access time"""
|
||||||
|
value = super().__getattribute__(name)
|
||||||
|
|
||||||
|
# If this is a _KEY, it must be not None, string, and length > 0 to return. We do not validate the correctness of the key.
|
||||||
|
# e.g. "thisIsAKey" is valid in this sense.
|
||||||
|
if name.endswith("_KEY") and (not value or (isinstance(value, str) and len(value) == 0)):
|
||||||
|
raise ValueError(f"API key '{name}' is not set or is empty. Please configure it before use.")
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
config = Configuration()
|
||||||
38
models.py
Normal file
38
models.py
Normal file
|
|
@ -0,0 +1,38 @@
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
# your “typo → canonical” map
|
||||||
|
_POWER_ALIASES = {
|
||||||
|
"EGMANY": "GERMANY",
|
||||||
|
"GERMAN": "GERMANY",
|
||||||
|
"UK": "ENGLAND",
|
||||||
|
"BRIT": "ENGLAND",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class PowerEnum(str, Enum):
|
||||||
|
AUSTRIA = "AUSTRIA"
|
||||||
|
ENGLAND = "ENGLAND"
|
||||||
|
FRANCE = "FRANCE"
|
||||||
|
GERMANY = "GERMANY"
|
||||||
|
ITALY = "ITALY"
|
||||||
|
RUSSIA = "RUSSIA"
|
||||||
|
TURKEY = "TURKEY"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _missing_(cls, value: Any) -> Optional["Enum"]:
|
||||||
|
"""
|
||||||
|
Called when you do PowerEnum(value) and `value` isn't one of the raw enum values.
|
||||||
|
Here we normalize strings to upper‐stripped, apply aliases, then retry.
|
||||||
|
"""
|
||||||
|
if isinstance(value, str):
|
||||||
|
normalized = value.upper().strip()
|
||||||
|
# apply any synonyms/typos
|
||||||
|
normalized = _POWER_ALIASES.get(normalized, normalized)
|
||||||
|
# look up in the normal value→member map
|
||||||
|
member = cls._value2member_map_.get(normalized)
|
||||||
|
if member is not None:
|
||||||
|
return member
|
||||||
|
|
||||||
|
# by default, let Enum raise the ValueError
|
||||||
|
return super()._missing_(value)
|
||||||
|
|
@ -15,6 +15,7 @@ dependencies = [
|
||||||
"json5>=0.12.0",
|
"json5>=0.12.0",
|
||||||
"matplotlib>=3.10.3",
|
"matplotlib>=3.10.3",
|
||||||
"openai>=1.90.0",
|
"openai>=1.90.0",
|
||||||
|
"pydantic-settings>=2.10.1",
|
||||||
"pylint>=2.3.0",
|
"pylint>=2.3.0",
|
||||||
"pytest>=4.4.0",
|
"pytest>=4.4.0",
|
||||||
"pytest-xdist>=3.7.0",
|
"pytest-xdist>=3.7.0",
|
||||||
|
|
@ -30,3 +31,6 @@ dependencies = [
|
||||||
"tqdm>=4.67.1",
|
"tqdm>=4.67.1",
|
||||||
"ujson>=5.10.0",
|
"ujson>=5.10.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
line-length = 150
|
||||||
|
|
|
||||||
21
tests/test_models.py
Normal file
21
tests/test_models.py
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
from models import PowerEnum
|
||||||
|
|
||||||
|
|
||||||
|
def test_power_name_aliases():
|
||||||
|
# Test all aliases defined in _POWER_ALIASES
|
||||||
|
assert PowerEnum("UK") == PowerEnum.ENGLAND
|
||||||
|
assert PowerEnum("BRIT") == PowerEnum.ENGLAND
|
||||||
|
assert PowerEnum("EGMANY") == PowerEnum.GERMANY
|
||||||
|
assert PowerEnum("GERMAN") == PowerEnum.GERMANY
|
||||||
|
|
||||||
|
# Test direct enum values (no alias needed)
|
||||||
|
assert PowerEnum("AUSTRIA") == PowerEnum.AUSTRIA
|
||||||
|
assert PowerEnum("FRANCE") == PowerEnum.FRANCE
|
||||||
|
|
||||||
|
# Test case insensitivity
|
||||||
|
assert PowerEnum("france") == PowerEnum.FRANCE
|
||||||
|
assert PowerEnum("iTaLy") == PowerEnum.ITALY
|
||||||
|
|
||||||
|
# Test with whitespace
|
||||||
|
assert PowerEnum(" RUSSIA ") == PowerEnum.RUSSIA
|
||||||
|
assert PowerEnum("TURKEY ") == PowerEnum.TURKEY
|
||||||
Loading…
Add table
Add a link
Reference in a new issue