# AI_Diplomacy/lm_game.py
# 2025-02-10 18:51:59 -08:00
#
# 304 lines
# No EOL
# 12 KiB
# Python

import logging
import time
import dotenv
import os
import re
import json
# Additional import for error stats
from collections import defaultdict
# Suppress Gemini/PaLM gRPC warnings
os.environ['GRPC_PYTHON_LOG_LEVEL'] = '40' # ERROR level only
import google.generativeai as genai # Import after setting log level
from diplomacy import Game
from diplomacy.utils.export import to_saved_game_format
# Added import: we'll create and add standard Diplomacy messages
from diplomacy.engine.message import Message, GLOBAL
# For concurrency:
import concurrent.futures
from lm_service_versus import load_model_client, assign_models_to_powers
# Load variables from a .env file (if one is present) into the environment
# before anything below reads configuration from it.
dotenv.load_dotenv()

# Root logging configuration: INFO and above, with a wall-clock timestamp.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)

# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
def gather_possible_orders(game, power_name):
    """
    Build the order menu for one power.

    Asks ``game`` for the locations ``power_name`` may issue orders from and
    for the table of every possible order, then returns a dict keyed by each
    of those locations. A location missing from the table maps to an empty
    list.
    """
    locations = game.get_orderable_locations(power_name)
    orders_by_location = game.get_all_possible_orders()
    return {place: orders_by_location.get(place, []) for place in locations}
def _negotiation_instructions(message_number, total_messages):
    """Return the few-shot format guide appended to a power's conversation context."""
    lines = [
        "",
        "Example response formats:",
        "1. For a global message:",
        "{",
        '    "message_type": "global",',
        '    "content": "I propose we all work together against Turkey."',
        "}",
        "2. For a private message:",
        "{",
        '    "message_type": "private",',
        '    "recipient": "FRANCE",',
        '    "content": "Let\'s form a secret alliance against Germany."',
        "}",
        f"Note: There are a total of {total_messages} messages in this negotiation phase. "
        f"This is message #{message_number} out of {total_messages}. "
        "By the end, you should have coordinated moves effectively to avoid being "
        "blocked or bounced with others.",
        "If you have your plan already figured out, you can just send a public '.' "
        "to indicate you're ready to move on.",
        "",
    ]
    return "\n".join(lines)


def _extract_json_object(text):
    """Return the first JSON object embedded in ``text`` as a dict, or None if there is none."""
    if not isinstance(text, str):
        return None
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode tolerates trailing prose and handles braces inside
            # string values or nested objects, unlike a "{...}" regex.
            parsed, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, dict):
            return parsed
        start = text.find("{", start + 1)
    return None


def conduct_negotiations(game, model_error_stats, max_rounds=10):
    """
    Conducts a round-robin conversation among all non-eliminated powers.

    Each power gets up to ``max_rounds`` turns. On each turn its model is asked
    for one JSON reply that is either a global message or a private message to
    another active power. Every accepted message is recorded locally (to build
    later context) and added to the game via ``game.add_message``.

    Args:
        game: Game object; this function reads ``powers``, ``power_model_map``,
            ``current_short_phase`` and ``phase_summaries`` and calls
            ``add_message``.
        model_error_stats: Mapping of model id to a dict of counters; the
            ``"conversation_errors"`` counter is incremented whenever a reply
            cannot be parsed.
        max_rounds: Number of single-message turns offered to each power.

    Returns:
        The list of accepted messages in send order, each a dict with
        ``"sender"``, ``"recipient"`` and ``"content"`` keys.
    """
    logger.info("Starting negotiation phase.")

    # Kept locally ONLY to build each power's conversation_so_far text.
    conversation_messages = []
    active_powers = [
        p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()
    ]

    for round_index in range(max_rounds):
        # The guide tells the model which message it is on, so build it per round.
        instructions = _negotiation_instructions(round_index + 1, max_rounds)

        for power_name in active_powers:
            # A power sees global messages plus anything it sent or received.
            visible_messages = [
                f"{msg['sender']} to {msg['recipient']}: {msg['content']}"
                for msg in conversation_messages
                if msg["recipient"] == GLOBAL
                or power_name in (msg["sender"], msg["recipient"])
            ]
            conversation_so_far = "\n".join(visible_messages)

            client = load_model_client(game.power_model_map.get(power_name, "o3-mini"))
            new_message = client.get_conversation_reply(
                power_name=power_name,
                conversation_so_far=conversation_so_far + "\n" + instructions,
                game_phase=game.current_short_phase,
                phase_summaries=game.phase_summaries,
            )
            if not new_message:
                continue

            message_data = _extract_json_object(new_message)
            content = message_data.get("content", "") if message_data else None
            if not isinstance(content, str):
                # No JSON object at all, or one whose content is unusable.
                logger.error(
                    f"Failed to parse message from {power_name}: "
                    "no JSON object with string content found"
                )
                model_id = game.power_model_map.get(power_name, "unknown")
                model_error_stats[model_id]["conversation_errors"] += 1
                continue

            content = content.strip()
            if not content:
                logger.debug(f"Power {power_name} returned an empty message; skipping.")
                continue

            # Only private messages carry a recipient; anything else is global,
            # so a stray or invalid "recipient" on a global message is ignored.
            message_type = message_data.get("message_type", "global")
            recipient = GLOBAL
            if message_type == "private":
                requested = message_data.get("recipient")
                if isinstance(requested, str):
                    # Accept the exact name first, then a case-normalized one,
                    # so "France" is not broadcast just for its casing.
                    candidates = (requested, requested.strip().upper())
                else:
                    candidates = ()
                recipient = next((c for c in candidates if c in active_powers), GLOBAL)
                if recipient == GLOBAL:
                    logger.warning(
                        f"Invalid recipient {requested} for private message, defaulting to GLOBAL"
                    )
                    message_type = "global"
            else:
                message_type = "global"

            logger.info(f"Power {power_name} sends {message_type} message to {recipient}")

            conversation_messages.append({
                "sender": power_name,
                "recipient": recipient,
                "content": content,
            })

            # Create an official message in the Diplomacy engine.
            diplo_message = Message(
                phase=game.current_short_phase,
                sender=power_name,
                recipient=recipient,
                message=content,
            )
            game.add_message(diplo_message)

    logger.info("Negotiation phase complete.")
    return conversation_messages
def my_summary_callback(system_prompt, user_prompt):
    """
    Summary hook handed to the game engine.

    Joins the system and user prompts with a blank line between them and
    returns whatever the "o3-mini" client generates for that text.
    """
    summarizer = load_model_client("o3-mini")
    prompt_text = "{}\n\n{}".format(system_prompt, user_prompt)
    return summarizer.generate_response(prompt_text)
def main(max_year=1901):
    """
    Run one multi-LLM Diplomacy test game and save its artefacts.

    For each phase: movement phases (short name ending in "M") first run a
    negotiation round, then every non-eliminated power's model is asked for
    orders concurrently, the orders are set on the game, and the phase is
    processed. The phase summary is appended to a manifesto file. When the
    loop ends, the game and per-model error counters are written as JSON into
    a timestamped folder under ``./results``.

    Args:
        max_year: Last game year to play; the loop stops as soon as the
            current phase belongs to a later year.
    """
    logger.info("Starting a new Diplomacy game for testing with multiple LLMs, now concurrent!")
    start_whole = time.time()

    # Per-model counters; the default factory lets callers increment blindly.
    model_error_stats = defaultdict(
        lambda: {"conversation_errors": 0, "order_decoding_errors": 0}
    )

    # Create a fresh Diplomacy game
    game = Game()
    # Ensure game has phase_summaries = {}
    if not hasattr(game, "phase_summaries"):
        game.phase_summaries = {}

    # For storing results in a unique subfolder
    timestamp_str = time.strftime("%Y%m%d_%H%M%S")
    result_folder = f"./results/{timestamp_str}"
    os.makedirs(result_folder, exist_ok=True)

    # Manifesto and game file paths
    manifesto_path = f"{result_folder}/game_manifesto.txt"
    game_file_path = f"{result_folder}/lmvsgame.json"
    stats_file_path = f"{result_folder}/error_stats.json"

    game.power_model_map = assign_models_to_powers()

    while not game.is_game_done:
        phase_start = time.time()
        current_phase = game.get_current_phase()
        logger.info(f"PHASE: {current_phase} (time so far: {phase_start - start_whole:.2f}s)")
        # DEBUG: Print the short phase to confirm
        logger.info(f"DEBUG: current_short_phase is '{game.current_short_phase}'")

        # Prevent unbounded sim. This is the only year check needed: it runs on
        # the fresh phase at the top of every iteration. A phase name without a
        # four-digit year is simply not limited instead of raising ValueError.
        year_match = re.search(r"\d{4}", current_phase)
        if year_match and int(year_match.group()) > max_year:
            logger.info(f"Reached year {int(year_match.group())}, stopping the test game early.")
            break

        # Use endswith("M") for movement phases (like F1901M, S1902M)
        if game.current_short_phase.endswith("M"):
            logger.info("Starting negotiation phase block...")
            conversation_messages = conduct_negotiations(game, model_error_stats, max_rounds=10)
        else:
            # No negotiation in non-movement phases (e.g. Retreat/Build)
            conversation_messages = []

        conversation_text_for_orders = "\n".join(
            f"{msg['sender']} to {msg['recipient']}: {msg['content']}"
            for msg in conversation_messages
        )

        # Gather orders from each power concurrently
        active_powers = [
            p_name for p_name, p_obj in game.powers.items()
            if not p_obj.is_eliminated()
        ]

        # ThreadPoolExecutor rejects max_workers=0, so keep at least one worker
        # even if no power is active.
        worker_count = max(1, len(active_powers))
        with concurrent.futures.ThreadPoolExecutor(max_workers=worker_count) as executor:
            futures = {}
            for power_name in active_powers:
                possible_orders = gather_possible_orders(game, power_name)
                if not possible_orders:
                    logger.info(f"No orderable locations for {power_name}; skipping.")
                    continue

                model_id = game.power_model_map.get(power_name, "o3-mini")
                client = load_model_client(model_id)
                board_state = game.get_state()
                future = executor.submit(
                    client.get_orders,
                    board_state,
                    power_name,
                    possible_orders,
                    conversation_text_for_orders,
                    game.phase_summaries,
                    model_error_stats,  # pass our stats
                )
                futures[future] = power_name
                logger.debug(f"Submitted get_orders task for power {power_name}.")

            for future in concurrent.futures.as_completed(futures):
                p_name = futures[future]
                try:
                    orders = future.result()
                except Exception as exc:
                    # One model failing must not abort the phase for the others.
                    logger.error(f"LLM request failed for {p_name}: {exc}")
                    continue
                logger.debug(f"Orders for {p_name}: {orders}")
                if orders:
                    game.set_orders(p_name, orders)
                    logger.debug(f"Set orders for {p_name} in {game.current_short_phase}: {orders}")
                else:
                    logger.debug(f"No orders returned for {p_name}.")

        logger.info("Processing orders...\n")
        phase_data = game.process(phase_summary_callback=my_summary_callback)
        logger.info("Phase complete.\n")

        summary_text = phase_data.summary or "(No summary found.)"

        # Print in pretty ASCII format
        border = "=" * 80
        logger.info(f"{border}\nPHASE SUMMARY for {phase_data.name}:\n{summary_text}\n{border}")

        # Append to the manifesto; summaries are model-generated text, so write
        # UTF-8 explicitly rather than relying on the platform default.
        with open(manifesto_path, "a", encoding="utf-8") as f:
            f.write(f"=== {phase_data.name} ===\n{summary_text}\n\n")

    # Save final result
    duration = time.time() - start_whole
    logger.info(f"Game ended after {duration:.2f}s. Saving to final JSON...")

    # Save final result to the unique subfolder, never overwriting an existing file
    output_path = game_file_path
    if os.path.exists(output_path):
        logger.info("Game file already exists, saving with unique filename.")
        output_path = f'{output_path}_{time.strftime("%Y%m%d_%H%M%S")}.json'
    to_saved_game_format(game, output_path=output_path)

    # Dump our error stats to JSON
    with open(stats_file_path, "w", encoding="utf-8") as stats_f:
        json.dump(model_error_stats, stats_f, indent=2)

    logger.info(f"Saved game data, manifesto, and error stats in: {result_folder}")
    logger.info("Done.")


# Script entry point: run a game only when executed directly, not on import.
if __name__ == "__main__":
    main()