mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-19 12:58:09 +00:00
162 lines
7.5 KiB
Python
162 lines
7.5 KiB
Python
from dotenv import load_dotenv
|
|
import logging
|
|
import asyncio
|
|
from typing import Dict, TYPE_CHECKING
|
|
|
|
from diplomacy.engine.message import Message, GLOBAL
|
|
|
|
from .agent import DiplomacyAgent
|
|
from .utils import gather_possible_orders, normalize_recipient_name
|
|
|
|
if TYPE_CHECKING:
|
|
from .game_history import GameHistory
|
|
from diplomacy import Game
|
|
|
|
logger = logging.getLogger("negotiations")
|
|
logger.setLevel(logging.INFO)
|
|
logging.basicConfig(level=logging.INFO)
|
|
|
|
load_dotenv()
|
|
|
|
|
|
async def conduct_negotiations(
|
|
game: "Game",
|
|
agents: Dict[str, DiplomacyAgent],
|
|
game_history: "GameHistory",
|
|
model_error_stats: Dict[str, Dict[str, int]],
|
|
log_file_path: str,
|
|
max_rounds: int = 3,
|
|
):
|
|
"""
|
|
Conducts a round-robin conversation among all non-eliminated powers.
|
|
Each power can send up to 'max_rounds' messages, choosing between private
|
|
and global messages each turn. Uses asyncio for concurrent message generation.
|
|
"""
|
|
logger.info("Starting negotiation phase.")
|
|
|
|
active_powers = [p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()]
|
|
eliminated_powers = [p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()]
|
|
|
|
logger.info(f"Active powers for negotiations: {active_powers}")
|
|
if eliminated_powers:
|
|
logger.info(f"Eliminated powers (skipped): {eliminated_powers}")
|
|
else:
|
|
logger.info("No eliminated powers yet.")
|
|
|
|
# We do up to 'max_rounds' single-message turns for each power
|
|
for round_index in range(max_rounds):
|
|
logger.info(f"Negotiation Round {round_index + 1}/{max_rounds}")
|
|
|
|
# Prepare tasks for asyncio.gather
|
|
tasks = []
|
|
power_names_for_tasks = []
|
|
|
|
for power_name in active_powers:
|
|
if power_name not in agents:
|
|
logger.warning(f"Agent for {power_name} not found in negotiations. Skipping.")
|
|
continue
|
|
agent = agents[power_name]
|
|
client = agent.client
|
|
|
|
possible_orders = gather_possible_orders(game, power_name)
|
|
if not possible_orders:
|
|
logger.info(f"No orderable locations for {power_name}; skipping message generation.")
|
|
continue
|
|
board_state = game.get_state()
|
|
|
|
# Append the coroutine to the tasks list
|
|
tasks.append(
|
|
client.get_conversation_reply(
|
|
game,
|
|
board_state,
|
|
power_name,
|
|
possible_orders,
|
|
game_history,
|
|
game.current_short_phase,
|
|
log_file_path=log_file_path,
|
|
active_powers=active_powers,
|
|
agent_goals=agent.goals,
|
|
agent_relationships=agent.relationships,
|
|
agent_private_diary_str=agent.format_private_diary_for_prompt(),
|
|
)
|
|
)
|
|
power_names_for_tasks.append(power_name)
|
|
logger.debug(f"Prepared get_conversation_reply task for {power_name}.")
|
|
|
|
# Run tasks concurrently if any were created
|
|
if tasks:
|
|
logger.debug(f"Running {len(tasks)} conversation tasks concurrently...")
|
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
|
else:
|
|
logger.debug("No conversation tasks to run for this round.")
|
|
results = []
|
|
|
|
# Process results
|
|
for i, result in enumerate(results):
|
|
power_name = power_names_for_tasks[i]
|
|
agent = agents[power_name] # Get agent again for journaling
|
|
model_name = agent.client.model_name # Get model name for stats
|
|
|
|
if isinstance(result, Exception):
|
|
logger.error(f"Error getting conversation reply for {power_name}: {result}", exc_info=result)
|
|
# Use model_name for stats key if possible
|
|
if model_name in model_error_stats:
|
|
model_error_stats[model_name]["conversation_errors"] += 1
|
|
else: # Fallback to power_name if model name not tracked (shouldn't happen)
|
|
model_error_stats.setdefault(power_name, {}).setdefault("conversation_errors", 0)
|
|
model_error_stats[power_name]["conversation_errors"] += 1
|
|
messages = [] # Treat as no messages on error
|
|
elif result is None: # Handle case where client might return None on internal error
|
|
logger.warning(f"Received None instead of messages for {power_name}.")
|
|
messages = []
|
|
if model_name in model_error_stats:
|
|
model_error_stats[model_name]["conversation_errors"] += 1
|
|
else:
|
|
model_error_stats.setdefault(power_name, {}).setdefault("conversation_errors", 0)
|
|
model_error_stats[power_name]["conversation_errors"] += 1
|
|
else:
|
|
messages = result # result is the list of message dicts
|
|
logger.debug(f"Received {len(messages)} message(s) from {power_name}.")
|
|
|
|
# Process the received messages (same logic as before)
|
|
if messages:
|
|
for message in messages:
|
|
# Validate message structure
|
|
if not isinstance(message, dict) or "content" not in message:
|
|
logger.warning(f"Invalid message format received from {power_name}: {message}. Skipping.")
|
|
continue
|
|
|
|
# Create an official message in the Diplomacy engine
|
|
# Determine recipient based on message type
|
|
if message.get("message_type") == "private":
|
|
recipient = normalize_recipient_name(message.get("recipient", GLOBAL)) # Default to GLOBAL if recipient missing somehow
|
|
if recipient not in game.powers and recipient != GLOBAL:
|
|
logger.warning(f"Invalid recipient '{recipient}' in message from {power_name}. Sending globally.")
|
|
recipient = GLOBAL # Fallback to GLOBAL if recipient power is invalid
|
|
else: # Assume global if not private or type is missing
|
|
recipient = GLOBAL
|
|
|
|
diplo_message = Message(
|
|
phase=game.current_short_phase,
|
|
sender=power_name,
|
|
recipient=recipient, # Use determined recipient
|
|
message=message.get("content", ""), # Use .get for safety
|
|
time_sent=None, # Let the engine assign time
|
|
)
|
|
game.add_message(diplo_message)
|
|
# Also add to our custom history
|
|
game_history.add_message(
|
|
game.current_short_phase,
|
|
power_name,
|
|
recipient, # Use determined recipient here too
|
|
message.get("content", ""), # Use .get for safety
|
|
)
|
|
journal_recipient = f"to {recipient}" if recipient != GLOBAL else "globally"
|
|
agent.add_journal_entry(f"Sent message {journal_recipient} in {game.current_short_phase}: {message.get('content', '')[:100]}...")
|
|
logger.info(f"[{power_name} -> {recipient}] {message.get('content', '')[:100]}...")
|
|
else:
|
|
logger.debug(f"No valid messages returned or error occurred for {power_name}.")
|
|
# Error stats handled above based on result type
|
|
|
|
logger.info("Negotiation phase complete.")
|
|
return game_history
|