mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-23 16:59:10 +00:00
170 lines
6.3 KiB
Python
170 lines
6.3 KiB
Python
from dotenv import load_dotenv
|
|
import logging
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass, field
|
|
from typing import Dict, List, Optional
|
|
|
|
logger = logging.getLogger("utils")
|
|
logger.setLevel(logging.INFO)
|
|
logging.basicConfig(level=logging.INFO)
|
|
load_dotenv()
|
|
|
|
|
|
@dataclass
|
|
class Message:
|
|
sender: str
|
|
recipient: str
|
|
content: str
|
|
|
|
|
|
@dataclass
|
|
class Phase:
|
|
name: str # e.g. "SPRING 1901"
|
|
plans: Dict[str, str] = field(default_factory=dict)
|
|
messages: List[Message] = field(default_factory=list)
|
|
orders_by_power: Dict[str, List[str]] = field(
|
|
default_factory=lambda: defaultdict(list)
|
|
)
|
|
results_by_power: Dict[str, List[List[str]]] = field(
|
|
default_factory=lambda: defaultdict(list)
|
|
)
|
|
|
|
def add_plan(self, power_name: str, plan: str):
|
|
self.plans[power_name] = plan
|
|
|
|
def add_message(self, sender: str, recipient: str, content: str):
|
|
self.messages.append(
|
|
Message(sender=sender, recipient=recipient, content=content)
|
|
)
|
|
|
|
def add_orders(self, power: str, orders: List[str], results: List[List[str]]):
|
|
self.orders_by_power[power].extend(orders)
|
|
# Make sure results has the same length as orders, if not, pad with empty lists
|
|
if len(results) < len(orders):
|
|
results.extend([[] for _ in range(len(orders) - len(results))])
|
|
self.results_by_power[power].extend(results)
|
|
|
|
def get_global_messages(self) -> str:
|
|
result = ""
|
|
for msg in self.messages:
|
|
if msg.recipient == "GLOBAL":
|
|
result += f" {msg.sender}: {msg.content}\n"
|
|
return result
|
|
|
|
def get_private_messages(self, power: str) -> Dict[str, str]:
|
|
conversations = defaultdict(str)
|
|
for msg in self.messages:
|
|
if msg.sender == power and msg.recipient != "GLOBAL":
|
|
conversations[msg.recipient] += f" {power}: {msg.content}\n"
|
|
elif msg.recipient == power:
|
|
conversations[msg.sender] += f" {msg.sender}: {msg.content}\n"
|
|
return conversations
|
|
|
|
def get_all_orders_formatted(self) -> str:
|
|
if not self.orders_by_power:
|
|
return ""
|
|
|
|
result = f"\nOrders for {self.name}:\n"
|
|
for power, orders in self.orders_by_power.items():
|
|
result += f"{power}:\n"
|
|
results = self.results_by_power.get(power, [])
|
|
for i, order in enumerate(orders):
|
|
if i < len(results) and results[i]:
|
|
# Join multiple results with commas
|
|
result_str = f" ({', '.join(results[i])})"
|
|
else:
|
|
result_str = " (successful)"
|
|
result += f" {order}{result_str}\n"
|
|
result += "\n"
|
|
return result
|
|
|
|
|
|
@dataclass
|
|
class GameHistory:
|
|
phases: List[Phase] = field(default_factory=list)
|
|
|
|
def add_phase(self, phase_name: str) -> Phase:
|
|
# Check if phase already exists
|
|
for phase in self.phases:
|
|
if phase.name == phase_name:
|
|
return phase
|
|
|
|
# Create new phase
|
|
new_phase = Phase(name=phase_name)
|
|
self.phases.append(new_phase)
|
|
return new_phase
|
|
|
|
def add_plan(self, phase_name: str, power_name: str, plan: str):
|
|
# get current phase
|
|
phase = self.add_phase(phase_name)
|
|
phase.add_plan(power_name, plan)
|
|
|
|
def add_message(self, phase_name: str, sender: str, recipient: str, content: str):
|
|
phase = self.add_phase(phase_name)
|
|
phase.add_message(sender, recipient, content)
|
|
|
|
def add_orders(
|
|
self, phase_name: str, power: str, orders: List[str], results: List[List[str]]
|
|
):
|
|
phase = self.add_phase(phase_name)
|
|
phase.add_orders(power, orders, results)
|
|
|
|
def get_strategic_directives(self):
|
|
# returns for last phase only if exists
|
|
if not self.phases:
|
|
return {}
|
|
return self.phases[-1].plans
|
|
|
|
def get_game_history(self, power_name: str, include_plans: bool = True, num_prev_phases: int = 5) -> str:
|
|
if not self.phases:
|
|
return ""
|
|
|
|
phases_to_report = self.phases[-num_prev_phases:]
|
|
game_history_str = ""
|
|
|
|
# Iterate through phases
|
|
for phase in phases_to_report:
|
|
game_history_str += f"\n{phase.name}:\n"
|
|
|
|
# Add GLOBAL section for this phase
|
|
global_msgs = phase.get_global_messages()
|
|
if global_msgs:
|
|
game_history_str += "\nGLOBAL:\n"
|
|
game_history_str += global_msgs
|
|
|
|
# Add PRIVATE section for this phase
|
|
private_msgs = phase.get_private_messages(power_name)
|
|
if private_msgs:
|
|
game_history_str += "\nPRIVATE:\n"
|
|
for other_power, messages in private_msgs.items():
|
|
game_history_str += f" {other_power}:\n\n"
|
|
game_history_str += messages + "\n"
|
|
|
|
# Add ORDERS section for this phase
|
|
if phase.orders_by_power:
|
|
game_history_str += "\nORDERS:\n"
|
|
for power, orders in phase.orders_by_power.items():
|
|
game_history_str += f"{power}:\n"
|
|
results = phase.results_by_power.get(power, [])
|
|
for i, order in enumerate(orders):
|
|
if (
|
|
i < len(results)
|
|
and results[i]
|
|
and not all(r == "" for r in results[i])
|
|
):
|
|
# Join multiple results with commas
|
|
result_str = f" ({', '.join(results[i])})"
|
|
else:
|
|
result_str = " (successful)"
|
|
game_history_str += f" {order}{result_str}\n"
|
|
game_history_str += "\n"
|
|
|
|
game_history_str += "-" * 50 + "\n" # Add separator between phases
|
|
|
|
# NOTE: only reports plan for the last phase (otherwise too much clutter)
|
|
if include_plans and phases_to_report and (power_name in phases_to_report[-1].plans):
|
|
game_history_str += f"\n{power_name} STRATEGIC DIRECTIVE:\n"
|
|
game_history_str += "Here is a high-level directive you have planned out previously for this phase.\n"
|
|
game_history_str += phases_to_report[-1].plans[power_name] + "\n"
|
|
|
|
return game_history_str
|