mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-19 12:58:09 +00:00
367 lines
16 KiB
Python
367 lines
16 KiB
Python
from dotenv import load_dotenv
|
|
import logging
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass, field
|
|
from typing import Dict, List, Optional
|
|
|
|
logger = logging.getLogger("utils")
|
|
logger.setLevel(logging.INFO)
|
|
logging.basicConfig(level=logging.INFO)
|
|
load_dotenv()
|
|
|
|
|
|
@dataclass
|
|
class Message:
|
|
sender: str
|
|
recipient: str
|
|
content: str
|
|
|
|
|
|
@dataclass
|
|
class Phase:
|
|
name: str # e.g. "SPRING 1901"
|
|
plans: Dict[str, str] = field(default_factory=dict)
|
|
messages: List[Message] = field(default_factory=list)
|
|
orders_by_power: Dict[str, List[str]] = field(
|
|
default_factory=lambda: defaultdict(list)
|
|
)
|
|
results_by_power: Dict[str, List[List[str]]] = field(
|
|
default_factory=lambda: defaultdict(list)
|
|
)
|
|
# NEW: Store phase-end summaries provided by each power
|
|
phase_summaries: Dict[str, str] = field(default_factory=dict)
|
|
# NEW: Store experience/journal updates from each power for this phase
|
|
experience_updates: Dict[str, str] = field(default_factory=dict)
|
|
|
|
def add_plan(self, power_name: str, plan: str):
|
|
self.plans[power_name] = plan
|
|
|
|
def add_message(self, sender: str, recipient: str, content: str):
|
|
self.messages.append(
|
|
Message(sender=sender, recipient=recipient, content=content)
|
|
)
|
|
|
|
def add_orders(self, power: str, orders: List[str], results: List[List[str]]):
|
|
self.orders_by_power[power].extend(orders)
|
|
# Make sure results has the same length as orders, if not, pad with empty lists
|
|
if len(results) < len(orders):
|
|
results.extend([[] for _ in range(len(orders) - len(results))])
|
|
self.results_by_power[power].extend(results)
|
|
|
|
def get_global_messages(self) -> str:
|
|
result = ""
|
|
for msg in self.messages:
|
|
if msg.recipient == "GLOBAL":
|
|
result += f" {msg.sender}: {msg.content}\n"
|
|
return result
|
|
|
|
def get_private_messages(self, power: str) -> Dict[str, str]:
|
|
conversations = defaultdict(str)
|
|
for msg in self.messages:
|
|
if msg.sender == power and msg.recipient != "GLOBAL":
|
|
conversations[msg.recipient] += f" {power}: {msg.content}\n"
|
|
elif msg.recipient == power:
|
|
conversations[msg.sender] += f" {msg.sender}: {msg.content}\n"
|
|
return conversations
|
|
|
|
def get_all_orders_formatted(self) -> str:
|
|
if not self.orders_by_power:
|
|
return ""
|
|
|
|
result = f"\nOrders for {self.name}:\n"
|
|
for power, orders in self.orders_by_power.items():
|
|
result += f"{power}:\n"
|
|
results = self.results_by_power.get(power, [])
|
|
for i, order in enumerate(orders):
|
|
if i < len(results) and results[i]:
|
|
# Join multiple results with commas
|
|
result_str = f" ({', '.join(results[i])})"
|
|
else:
|
|
result_str = " (successful)"
|
|
result += f" {order}{result_str}\n"
|
|
result += "\n"
|
|
return result
|
|
|
|
|
|
@dataclass
|
|
class GameHistory:
|
|
phases: List[Phase] = field(default_factory=list)
|
|
|
|
def add_phase(self, phase_name: str):
|
|
# Avoid adding duplicate phases
|
|
if not self.phases or self.phases[-1].name != phase_name:
|
|
self.phases.append(Phase(name=phase_name))
|
|
logger.debug(f"Added new phase: {phase_name}")
|
|
else:
|
|
logger.warning(f"Phase {phase_name} already exists. Not adding again.")
|
|
|
|
def _get_phase(self, phase_name: str) -> Optional[Phase]:
|
|
for phase in reversed(self.phases):
|
|
if phase.name == phase_name:
|
|
return phase
|
|
logger.error(f"Phase {phase_name} not found in history.")
|
|
return None
|
|
|
|
def add_plan(self, phase_name: str, power_name: str, plan: str):
|
|
phase = self._get_phase(phase_name)
|
|
if phase:
|
|
phase.plans[power_name] = plan
|
|
logger.debug(f"Added plan for {power_name} in {phase_name}")
|
|
|
|
def add_message(
|
|
self, phase_name: str, sender: str, recipient: str, message_content: str
|
|
):
|
|
phase = self._get_phase(phase_name)
|
|
if phase:
|
|
message = Message(
|
|
sender=sender, recipient=recipient, content=message_content
|
|
)
|
|
phase.messages.append(message)
|
|
logger.debug(f"Added message from {sender} to {recipient} in {phase_name}")
|
|
|
|
def add_orders(self, phase_name: str, power_name: str, orders: List[str]):
|
|
phase = self._get_phase(phase_name)
|
|
if phase:
|
|
phase.orders_by_power[power_name].extend(orders)
|
|
logger.debug(f"Added orders for {power_name} in {phase_name}: {orders}")
|
|
|
|
def add_results(self, phase_name: str, power_name: str, results: List[List[str]]):
|
|
phase = self._get_phase(phase_name)
|
|
if phase:
|
|
phase.results_by_power[power_name].extend(results)
|
|
logger.debug(f"Added results for {power_name} in {phase_name}: {results}")
|
|
|
|
# NEW: Method to add phase summary for a power
|
|
def add_phase_summary(self, phase_name: str, power_name: str, summary: str):
|
|
phase = self._get_phase(phase_name)
|
|
if phase:
|
|
phase.phase_summaries[power_name] = summary
|
|
logger.debug(f"Added phase summary for {power_name} in {phase_name}")
|
|
|
|
# NEW: Method to add experience update for a power
|
|
def add_experience_update(self, phase_name: str, power_name: str, update: str):
|
|
phase = self._get_phase(phase_name)
|
|
if phase:
|
|
phase.experience_updates[power_name] = update
|
|
logger.debug(f"Added experience update for {power_name} in {phase_name}")
|
|
|
|
def get_strategic_directives(self):
|
|
# returns for last phase only if exists
|
|
if not self.phases:
|
|
return {}
|
|
return self.phases[-1].plans
|
|
|
|
# NEW METHOD
|
|
def get_messages_this_round(self, power_name: str, current_phase_name: str) -> str:
|
|
current_phase: Optional[Phase] = None
|
|
for phase_obj in self.phases:
|
|
if phase_obj.name == current_phase_name:
|
|
current_phase = phase_obj
|
|
break
|
|
|
|
if not current_phase:
|
|
return f"\n(No messages found for current phase: {current_phase_name})\n"
|
|
|
|
messages_str = ""
|
|
|
|
global_msgs_content = current_phase.get_global_messages()
|
|
if global_msgs_content:
|
|
messages_str += "**GLOBAL MESSAGES THIS ROUND:**\n"
|
|
messages_str += global_msgs_content
|
|
else:
|
|
messages_str += "**GLOBAL MESSAGES THIS ROUND:**\n (No global messages this round)\n"
|
|
|
|
private_msgs_dict = current_phase.get_private_messages(power_name)
|
|
if private_msgs_dict:
|
|
messages_str += "\n**PRIVATE MESSAGES TO/FROM YOU THIS ROUND:**\n"
|
|
for other_power, conversation_content in private_msgs_dict.items():
|
|
messages_str += f" Conversation with {other_power}:\n"
|
|
messages_str += conversation_content
|
|
messages_str += "\n"
|
|
else:
|
|
messages_str += "\n**PRIVATE MESSAGES TO/FROM YOU THIS ROUND:**\n (No private messages this round)\n"
|
|
|
|
if not global_msgs_content and not private_msgs_dict:
|
|
return f"\n(No messages recorded for current phase: {current_phase_name})\n"
|
|
|
|
return messages_str.strip()
|
|
|
|
# New method to get recent messages TO a specific power
|
|
def get_recent_messages_to_power(self, power_name: str, limit: int = 3) -> List[Dict[str, str]]:
|
|
"""
|
|
Gets the most recent messages sent TO this power, useful for tracking messages that need replies.
|
|
Returns a list of dictionaries with 'sender', 'content', and 'phase' keys.
|
|
"""
|
|
if not self.phases:
|
|
return []
|
|
|
|
# Get the most recent 2 phases including current phase
|
|
recent_phases = self.phases[-2:] if len(self.phases) >= 2 else self.phases[-1:]
|
|
|
|
# Collect all messages sent TO this power
|
|
messages_to_power = []
|
|
for phase in recent_phases:
|
|
for msg in phase.messages:
|
|
# Personal messages to this power or global messages from others
|
|
if msg.recipient == power_name or (msg.recipient == "GLOBAL" and msg.sender != power_name):
|
|
# Skip if sender is this power (don't need to respond to own messages)
|
|
if msg.sender != power_name:
|
|
messages_to_power.append({
|
|
'sender': msg.sender,
|
|
'content': msg.content,
|
|
'phase': phase.name
|
|
})
|
|
|
|
# Add debug logging
|
|
logger.info(f"Found {len(messages_to_power)} messages to {power_name} across {len(recent_phases)} phases")
|
|
if not messages_to_power:
|
|
logger.info(f"No messages found for {power_name} to respond to")
|
|
|
|
# Take the most recent 'limit' messages
|
|
return messages_to_power[-limit:] if messages_to_power else []
|
|
|
|
def get_ignored_messages_by_power(self, sender_name: str, num_phases: int = 3) -> Dict[str, List[Dict[str, str]]]:
|
|
"""
|
|
Identifies which powers are not responding to messages from sender_name.
|
|
Returns a dict mapping power names to their ignored messages.
|
|
|
|
A message is considered ignored if:
|
|
1. It was sent from sender_name to another power (private)
|
|
2. No response from that power was received in the same or next phase
|
|
"""
|
|
ignored_by_power = {}
|
|
|
|
# Get recent phases
|
|
recent_phases = self.phases[-num_phases:] if self.phases else []
|
|
if not recent_phases:
|
|
return ignored_by_power
|
|
|
|
for i, phase in enumerate(recent_phases):
|
|
# Get messages sent by sender to specific powers (not global)
|
|
sender_messages = []
|
|
for msg in phase.messages:
|
|
# Handle both Message objects and dict objects
|
|
if isinstance(msg, Message):
|
|
if msg.sender == sender_name and msg.recipient not in ['GLOBAL', 'ALL']:
|
|
sender_messages.append(msg)
|
|
else: # Assume dict
|
|
if msg['sender'] == sender_name and msg['recipient'] not in ['GLOBAL', 'ALL']:
|
|
sender_messages.append(msg)
|
|
|
|
# Check for responses in this and next phases
|
|
for msg in sender_messages:
|
|
# Handle both Message objects and dict objects
|
|
if isinstance(msg, Message):
|
|
recipient = msg.recipient
|
|
msg_content = msg.content
|
|
else:
|
|
recipient = msg['recipient']
|
|
msg_content = msg['content']
|
|
|
|
# Look for responses in current phase and next phases
|
|
found_response = False
|
|
|
|
# Check remaining phases starting from current
|
|
for check_phase in recent_phases[i:min(i+2, len(recent_phases))]:
|
|
# Look for messages FROM the recipient TO the sender (direct response)
|
|
# or FROM the recipient to GLOBAL/ALL that might acknowledge sender
|
|
response_msgs = []
|
|
for m in check_phase.messages:
|
|
if isinstance(m, Message):
|
|
if m.sender == recipient and (m.recipient == sender_name or
|
|
(m.recipient in ['GLOBAL', 'ALL'] and sender_name in m.content)):
|
|
response_msgs.append(m)
|
|
else: # Assume dict
|
|
if m['sender'] == recipient and (m['recipient'] == sender_name or
|
|
(m['recipient'] in ['GLOBAL', 'ALL'] and sender_name in m.get('content', ''))):
|
|
response_msgs.append(m)
|
|
|
|
if response_msgs:
|
|
found_response = True
|
|
break
|
|
|
|
if not found_response:
|
|
if recipient not in ignored_by_power:
|
|
ignored_by_power[recipient] = []
|
|
ignored_by_power[recipient].append({
|
|
'phase': phase.name,
|
|
'content': msg_content
|
|
})
|
|
|
|
return ignored_by_power
|
|
|
|
# MODIFIED METHOD (renamed from get_game_history)
|
|
def get_previous_phases_history(
|
|
self, power_name: str, current_phase_name: str, include_plans: bool = True, num_prev_phases: int = 5
|
|
) -> str:
|
|
if not self.phases:
|
|
return "\n(No game history available)\n"
|
|
|
|
relevant_phases = [p for p in self.phases if p.name != current_phase_name]
|
|
|
|
if not relevant_phases:
|
|
return "\n(No previous game history before this round)\n"
|
|
|
|
phases_to_report = relevant_phases[-num_prev_phases:]
|
|
|
|
if not phases_to_report:
|
|
return "\n(No previous game history available within the lookback window)\n"
|
|
|
|
game_history_str = ""
|
|
|
|
for phase_idx, phase in enumerate(phases_to_report):
|
|
phase_content_str = f"\nPHASE: {phase.name}\n"
|
|
current_phase_has_content = False
|
|
|
|
global_msgs = phase.get_global_messages()
|
|
if global_msgs:
|
|
phase_content_str += "\n GLOBAL MESSAGES:\n"
|
|
phase_content_str += "".join([f" {line}\n" for line in global_msgs.strip().split('\n')])
|
|
current_phase_has_content = True
|
|
|
|
private_msgs = phase.get_private_messages(power_name)
|
|
if private_msgs:
|
|
phase_content_str += "\n PRIVATE MESSAGES:\n"
|
|
for other_power, messages in private_msgs.items():
|
|
phase_content_str += f" Conversation with {other_power}:\n"
|
|
phase_content_str += "".join([f" {line}\n" for line in messages.strip().split('\n')])
|
|
current_phase_has_content = True
|
|
|
|
if phase.orders_by_power:
|
|
phase_content_str += "\n ORDERS:\n"
|
|
for power, orders in phase.orders_by_power.items():
|
|
indicator = " (your power)" if power == power_name else ""
|
|
phase_content_str += f" {power}{indicator}:\n"
|
|
results = phase.results_by_power.get(power, [])
|
|
for i, order in enumerate(orders):
|
|
result_str = " (successful)"
|
|
if i < len(results) and results[i] and not all(r == "" for r in results[i]):
|
|
result_str = f" ({', '.join(results[i])})"
|
|
phase_content_str += f" {order}{result_str}\n"
|
|
phase_content_str += "\n"
|
|
current_phase_has_content = True
|
|
|
|
if current_phase_has_content:
|
|
if not game_history_str:
|
|
game_history_str = "**PREVIOUS GAME HISTORY (Messages, Orders, & Plans from older rounds & phases)**\n"
|
|
game_history_str += phase_content_str
|
|
if phase_idx < len(phases_to_report) -1 :
|
|
game_history_str += " " + "-" * 48 + "\n"
|
|
|
|
if include_plans and phases_to_report:
|
|
last_reported_previous_phase = phases_to_report[-1]
|
|
if last_reported_previous_phase.plans:
|
|
if not game_history_str:
|
|
game_history_str = "**PREVIOUS GAME HISTORY (Messages, Orders, & Plans from older rounds & phases)**\n"
|
|
game_history_str += f"\n PLANS SUBMITTED FOR PHASE {last_reported_previous_phase.name}:\n"
|
|
if power_name in last_reported_previous_phase.plans:
|
|
game_history_str += f" Your Plan: {last_reported_previous_phase.plans[power_name]}\n"
|
|
for p_other, plan_other in last_reported_previous_phase.plans.items():
|
|
if p_other != power_name:
|
|
game_history_str += f" {p_other}'s Plan: {plan_other}\n"
|
|
game_history_str += "\n"
|
|
|
|
if not game_history_str.replace("**PREVIOUS GAME HISTORY (Messages, Orders, & Plans from older rounds & phases)**\n", "").strip():
|
|
return "\n(No relevant previous game history to display)\n"
|
|
|
|
return game_history_str.strip()
|