mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-04-19 12:58:09 +00:00
add optional planning phase
This commit is contained in:
parent
98cb767d6c
commit
1f8ac5ae20
11 changed files with 222 additions and 77 deletions
|
|
@ -59,6 +59,7 @@ class BaseModelClient:
|
|||
power_name: str,
|
||||
possible_orders: Dict[str, List[str]],
|
||||
game_history: GameHistory,
|
||||
include_plans: bool = True
|
||||
) -> str:
|
||||
context = load_prompt("context_prompt.txt")
|
||||
|
||||
|
|
@ -91,7 +92,7 @@ class BaseModelClient:
|
|||
if fleet in units_info_set:
|
||||
convoy_paths_possible.append((start_loc, fleets_req, end_loc))
|
||||
|
||||
conversation_text = game_history.get_game_history(power_name)
|
||||
conversation_text = game_history.get_game_history(power_name, include_plans=include_plans)
|
||||
if not conversation_text:
|
||||
conversation_text = "\n(No game history yet)\n"
|
||||
|
||||
|
|
@ -325,6 +326,29 @@ class BaseModelClient:
|
|||
fallback.append(holds[0] if holds else orders_list[0])
|
||||
return fallback
|
||||
|
||||
def build_planning_prompt(
|
||||
self,
|
||||
game,
|
||||
board_state,
|
||||
power_name: str,
|
||||
possible_orders: Dict[str, List[str]],
|
||||
game_history: GameHistory,
|
||||
game_phase: str,
|
||||
) -> str:
|
||||
|
||||
instructions = load_prompt("planning_instructions.txt")
|
||||
|
||||
context = self.build_context_prompt(
|
||||
game,
|
||||
board_state,
|
||||
power_name,
|
||||
possible_orders,
|
||||
game_history,
|
||||
include_plans=False
|
||||
)
|
||||
|
||||
return context + "\n\n" + instructions
|
||||
|
||||
def build_conversation_prompt(
|
||||
self,
|
||||
game,
|
||||
|
|
@ -346,6 +370,29 @@ class BaseModelClient:
|
|||
|
||||
return context + "\n\n" + instructions
|
||||
|
||||
def get_planning_reply(
|
||||
self,
|
||||
game,
|
||||
board_state,
|
||||
power_name: str,
|
||||
possible_orders: Dict[str, List[str]],
|
||||
game_history: GameHistory,
|
||||
game_phase: str,
|
||||
active_powers: Optional[List[str]] = None,
|
||||
) -> str:
|
||||
|
||||
prompt = self.build_planning_prompt(
|
||||
game,
|
||||
board_state,
|
||||
power_name,
|
||||
possible_orders,
|
||||
game_history,
|
||||
game_phase,
|
||||
)
|
||||
|
||||
raw_response = self.generate_response(prompt)
|
||||
return raw_response
|
||||
|
||||
def get_conversation_reply(
|
||||
self,
|
||||
game,
|
||||
|
|
@ -356,6 +403,7 @@ class BaseModelClient:
|
|||
game_phase: str,
|
||||
active_powers: Optional[List[str]] = None,
|
||||
) -> str:
|
||||
|
||||
prompt = self.build_conversation_prompt(
|
||||
game,
|
||||
board_state,
|
||||
|
|
@ -368,6 +416,7 @@ class BaseModelClient:
|
|||
raw_response = self.generate_response(prompt)
|
||||
|
||||
messages = []
|
||||
import pdb; pdb.set_trace()
|
||||
if raw_response:
|
||||
try:
|
||||
# Find the JSON block between double curly braces
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ class Message:
|
|||
@dataclass
|
||||
class Phase:
|
||||
name: str # e.g. "SPRING 1901"
|
||||
plans: Dict[str, str] = field(default_factory=dict)
|
||||
messages: List[Message] = field(default_factory=list)
|
||||
orders_by_power: Dict[str, List[str]] = field(
|
||||
default_factory=lambda: defaultdict(list)
|
||||
|
|
@ -28,6 +29,9 @@ class Phase:
|
|||
default_factory=lambda: defaultdict(list)
|
||||
)
|
||||
|
||||
def add_plan(self, power_name: str, plan: str):
|
||||
self.plans[power_name] = plan
|
||||
|
||||
def add_message(self, sender: str, recipient: str, content: str):
|
||||
self.messages.append(
|
||||
Message(sender=sender, recipient=recipient, content=content)
|
||||
|
|
@ -90,6 +94,11 @@ class GameHistory:
|
|||
self.phases.append(new_phase)
|
||||
return new_phase
|
||||
|
||||
def add_plan(self, phase_name: str, power_name: str, plan: str):
|
||||
# get current phase
|
||||
phase = self.add_phase(phase_name)
|
||||
phase.add_plan(power_name, plan)
|
||||
|
||||
def add_message(self, phase_name: str, sender: str, recipient: str, content: str):
|
||||
phase = self.add_phase(phase_name)
|
||||
phase.add_message(sender, recipient, content)
|
||||
|
|
@ -100,7 +109,13 @@ class GameHistory:
|
|||
phase = self.add_phase(phase_name)
|
||||
phase.add_orders(power, orders, results)
|
||||
|
||||
def get_game_history(self, power_name: str, num_prev_phases: int = 5) -> str:
|
||||
def get_strategic_directives(self):
|
||||
# returns for last phase only if exists
|
||||
if not self.phases:
|
||||
return {}
|
||||
return self.phases[-1].plans
|
||||
|
||||
def get_game_history(self, power_name: str, include_plans: bool = True, num_prev_phases: int = 5) -> str:
|
||||
if not self.phases:
|
||||
return ""
|
||||
|
||||
|
|
@ -146,4 +161,10 @@ class GameHistory:
|
|||
|
||||
game_history_str += "-" * 50 + "\n" # Add separator between phases
|
||||
|
||||
# NOTE: only reports plan for the last phase (otherwise too much clutter)
|
||||
if include_plans and phases_to_report and (power_name in phases_to_report[-1].plans):
|
||||
game_history_str += f"\n{power_name} STRATEGIC DIRECTIVE:\n"
|
||||
game_history_str += "Here is a high-level directive you have planned out previously for this phase.\n"
|
||||
game_history_str += phases_to_report[-1].plans[power_name] + "\n"
|
||||
|
||||
return game_history_str
|
||||
|
|
|
|||
|
|
@ -22,9 +22,6 @@ def conduct_negotiations(game, game_history, model_error_stats, max_rounds=3):
|
|||
"""
|
||||
logger.info("Starting negotiation phase.")
|
||||
|
||||
# Conversation messages are kept in a local list ONLY to build conversation_so_far text.
|
||||
conversation_messages = []
|
||||
|
||||
active_powers = [
|
||||
p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()
|
||||
]
|
||||
|
|
@ -32,7 +29,7 @@ def conduct_negotiations(game, game_history, model_error_stats, max_rounds=3):
|
|||
# We do up to 'max_rounds' single-message turns for each power
|
||||
for round_index in range(max_rounds):
|
||||
with concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=len(active_powers)
|
||||
max_workers=1
|
||||
) as executor:
|
||||
futures = {}
|
||||
for power_name in active_powers:
|
||||
|
|
@ -60,7 +57,6 @@ def conduct_negotiations(game, game_history, model_error_stats, max_rounds=3):
|
|||
|
||||
for future in concurrent.futures.as_completed(futures):
|
||||
power_name = futures[future]
|
||||
try:
|
||||
messages = future.result()
|
||||
if messages:
|
||||
for message in messages:
|
||||
|
|
@ -78,12 +74,11 @@ def conduct_negotiations(game, game_history, model_error_stats, max_rounds=3):
|
|||
message["recipient"],
|
||||
message["content"],
|
||||
)
|
||||
conversation_messages.append(message)
|
||||
else:
|
||||
logger.debug(f"No valid messages returned for {power_name}.")
|
||||
model_error_stats[power_name]["conversation_errors"] += 1
|
||||
except Exception as exc:
|
||||
logger.error(f"LLM request failed for {power_name}: {exc}")
|
||||
# except Exception as exc:
|
||||
# logger.error(f"LLM request failed for {power_name}: {exc}")
|
||||
|
||||
logger.info("Negotiation phase complete.")
|
||||
return conversation_messages
|
||||
return game_history
|
||||
|
|
|
|||
67
ai_diplomacy/planning.py
Normal file
67
ai_diplomacy/planning.py
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
from dotenv import load_dotenv
|
||||
import logging
|
||||
import concurrent.futures
|
||||
|
||||
from diplomacy.engine.message import Message, GLOBAL
|
||||
|
||||
from .clients import load_model_client
|
||||
from .utils import gather_possible_orders
|
||||
|
||||
logger = logging.getLogger("utils")
|
||||
logger.setLevel(logging.INFO)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def planning_phase(game, game_history, model_error_stats, max_rounds=3):
|
||||
"""
|
||||
Lets each power form a strategic directive for the upcoming phase.
|
||||
"""
|
||||
active_powers = [
|
||||
p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()
|
||||
]
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=len(active_powers)
|
||||
) as executor:
|
||||
futures = {}
|
||||
for power_name in active_powers:
|
||||
model_id = game.power_model_map.get(power_name, "o3-mini")
|
||||
client = load_model_client(model_id)
|
||||
possible_orders = gather_possible_orders(game, power_name)
|
||||
if not possible_orders:
|
||||
logger.info(f"No orderable locations for {power_name}; skipping.")
|
||||
continue
|
||||
board_state = game.get_state()
|
||||
|
||||
future = executor.submit(
|
||||
client.get_planning_reply,
|
||||
game,
|
||||
board_state,
|
||||
power_name,
|
||||
possible_orders,
|
||||
game_history,
|
||||
game.current_short_phase,
|
||||
active_powers,
|
||||
)
|
||||
|
||||
futures[future] = power_name
|
||||
logger.debug(f"Submitted get_planning_reply task for {power_name}.")
|
||||
|
||||
logger.info("Waiting for planning replies...")
|
||||
for future in concurrent.futures.as_completed(futures):
|
||||
power_name = futures[future]
|
||||
try:
|
||||
reply = future.result()
|
||||
logger.info(f"Received planning reply from {power_name}.")
|
||||
if reply:
|
||||
game_history.add_plan(
|
||||
game.current_short_phase, power_name, reply
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in planning reply for {power_name}: {e}")
|
||||
model_error_stats[power_name] += 1
|
||||
|
||||
logger.info("Planning phase complete.")
|
||||
return game_history
|
||||
|
|
@ -39,6 +39,6 @@ PARSABLE OUTPUT:
|
|||
"content": "Let's form a secret alliance against Germany."
|
||||
}}
|
||||
|
||||
Reason step-by-step through your diplomatic position and the past phase summaries. Make sure you give thorough and clear reasoning. Send messages that tactically advance your position in the future. Be concrete in your messages. Your goal is to acquire supply centers and win the game. You can lie to achieve this goal. If there is no value in sending messages, then don't send any messages. Just respond with '.'.
|
||||
Reason step-by-step through your diplomatic position, the state of the game, and your past strategic directive. Make sure you give thorough and clear reasoning. Send messages that tactically advance your position in the future. Be concrete in your messages. Your goal is to acquire supply centers and win the game. You can lie to achieve this goal. If there is no value in sending messages, then don't send any messages. Just respond with '.'.
|
||||
|
||||
Remember you are in messaging mode right now. Do not give any orders. Each message you send must be in a separate bracketed JSON section. You may send multiple messages. Do not send them all as one JSON object. Do not repeat any JSON object in your output twice.
|
||||
14
ai_diplomacy/prompts/planning_instructions.txt
Normal file
14
ai_diplomacy/prompts/planning_instructions.txt
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
**CURRENT INSTRUCTIONS**
|
||||
|
||||
You are now to prepare a high-level strategic directive for this phase. Remember that your goal is to win via capturing supply centers. There are opportunity costs in this game.
|
||||
|
||||
You should reason through
|
||||
1. What would be ideal locations for you to capture during this phase (or in future phases),
|
||||
2. What needs to happen for those orders to be successful (e.g. not bounce),
|
||||
3. Are there strategic negotiations you could engage in with other powers to make those bounces less likely?,
|
||||
4. Will other powers be targeting any of your supply centers in this phase (or in future phases)?,
|
||||
5. Are there strategic negotiations you could engage in to reduce the chance they target your supply centers?
|
||||
|
||||
Develop a detailed directive that reasons through the supply centers you'd like to take and how you'll take them (including possible orders and possible alliances, concessions, lies, etc). Note that you can make deals, threats, or alliances with other powers. You can lie.
|
||||
|
||||
Note that the plan you outline below will be presented to yourself when you are asked to generate specific negotiation messages and specific orders. Be detailed, strategic, smart, and forward-thinking. Be wary of others trying to deceive you. Be specific as your response should help you produce messages for negotiation and orders later. Output a directive that is roughly a paragraph long.
|
||||
|
|
@ -5,8 +5,9 @@ You will be given:
|
|||
• The current phase (e.g. S1901M).
|
||||
• Details about the map.
|
||||
• Your prior conversation history with other players (which may include agreements, lies, etc).
|
||||
* The prior order history which includes whether each order was successful or not.
|
||||
• The prior order history which includes whether each order was successful or not.
|
||||
• A strategic plan that you have made if you are in the negotiations or orders phase.
|
||||
• Your units and the possible orders you may make. **Always refer to these possible_orders.**
|
||||
• A list of enemy units and centers.
|
||||
|
||||
Remember that while your private chain-of-thought can consider your in-depth reasoning about possible outcomes, **only** the “PARSABLE OUTPUT” (your final orders or messages) will be used by the game engine.
|
||||
For the negotiations and orders phase, remember that while your private chain-of-thought can consider your in-depth reasoning about possible outcomes, **only** the “PARSABLE OUTPUT” (your final orders or messages) will be used by the game engine.
|
||||
|
|
@ -17,9 +17,9 @@ def assign_models_to_powers():
|
|||
|
||||
return {
|
||||
"FRANCE": "o3-mini",
|
||||
"GERMANY": "claude-3-5-sonnet-20241022",
|
||||
"GERMANY": "claude-3-5-sonnet-latest",
|
||||
"ENGLAND": "gemini-2.0-flash",
|
||||
"RUSSIA": "gemini-2.0-flash-lite-preview-02-05",
|
||||
"RUSSIA": "claude-3.7-sonnet-latest",
|
||||
"ITALY": "gpt-4o",
|
||||
"AUSTRIA": "gpt-4o-mini",
|
||||
"TURKEY": "claude-3-5-haiku-20241022",
|
||||
|
|
|
|||
57
lm_game.py
57
lm_game.py
|
|
@ -20,6 +20,7 @@ from ai_diplomacy.utils import (
|
|||
assign_models_to_powers,
|
||||
)
|
||||
from ai_diplomacy.negotiations import conduct_negotiations
|
||||
from ai_diplomacy.planning import planning_phase
|
||||
from ai_diplomacy.game_history import GameHistory
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
|
@ -32,14 +33,6 @@ logging.basicConfig(
|
|||
)
|
||||
|
||||
|
||||
def my_summary_callback(system_prompt, user_prompt, model_name):
|
||||
# Route to the desired model specified by the command-line argument
|
||||
client = load_model_client(model_name)
|
||||
combined_prompt = f"{system_prompt}\n\n{user_prompt}"
|
||||
# Pseudo-code for generating a response:
|
||||
return client.generate_response(combined_prompt)
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run a Diplomacy game simulation with configurable parameters."
|
||||
|
|
@ -50,12 +43,6 @@ def parse_arguments():
|
|||
default=1901,
|
||||
help="Maximum year to simulate. The game will stop once this year is reached.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--summary_model",
|
||||
type=str,
|
||||
default="o3-mini",
|
||||
help="Model name to use for generating phase summaries.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_negotiation_rounds",
|
||||
type=int,
|
||||
|
|
@ -77,13 +64,17 @@ def parse_arguments():
|
|||
"The order is: AUSTRIA, ENGLAND, FRANCE, GERMANY, ITALY, RUSSIA, TURKEY."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--planning_phase",
|
||||
action="store_true",
|
||||
help="Enable the planning phase for each power to set strategic directives.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_arguments()
|
||||
max_year = args.max_year
|
||||
summary_model = args.summary_model
|
||||
|
||||
logger.info(
|
||||
"Starting a new Diplomacy game for testing with multiple LLMs, now concurrent!"
|
||||
|
|
@ -154,15 +145,21 @@ def main():
|
|||
|
||||
# If it's a movement phase (e.g. ends with "M"), conduct negotiations
|
||||
if game.current_short_phase.endswith("M"):
|
||||
|
||||
if args.planning_phase:
|
||||
logger.info("Starting planning phase block...")
|
||||
game_history = planning_phase(
|
||||
game,
|
||||
game_history,
|
||||
model_error_stats,
|
||||
)
|
||||
logger.info("Starting negotiation phase block...")
|
||||
conversation_messages = conduct_negotiations(
|
||||
game_history = conduct_negotiations(
|
||||
game,
|
||||
game_history,
|
||||
model_error_stats,
|
||||
max_rounds=args.num_negotiation_rounds,
|
||||
)
|
||||
else:
|
||||
conversation_messages = []
|
||||
|
||||
# Gather orders from each power concurrently
|
||||
active_powers = [
|
||||
|
|
@ -213,12 +210,7 @@ def main():
|
|||
logger.error(f"LLM request failed for {p_name}: {exc}")
|
||||
|
||||
logger.info("Processing orders...\n")
|
||||
# Pass the summary model to the callback via a lambda function
|
||||
phase_data = game.process(
|
||||
phase_summary_callback=lambda sys, usr: my_summary_callback(
|
||||
sys, usr, summary_model
|
||||
)
|
||||
)
|
||||
game.process()
|
||||
# Add orders to game history
|
||||
for power_name in game.order_history[current_phase]:
|
||||
orders = game.order_history[current_phase][power_name]
|
||||
|
|
@ -241,16 +233,15 @@ def main():
|
|||
)
|
||||
logger.info("Phase complete.\n")
|
||||
|
||||
# Retrieve and log the summary of the phase
|
||||
summary_text = phase_data.summary or "(No summary found.)"
|
||||
border = "=" * 80
|
||||
logger.info(
|
||||
f"{border}\nPHASE SUMMARY for {phase_data.name}:\n{summary_text}\n{border}"
|
||||
)
|
||||
|
||||
# Append the summary to the manifesto file
|
||||
# Append the strategic directives to the manifesto file
|
||||
strategic_directives = game_history.get_strategic_directives()
|
||||
if strategic_directives:
|
||||
out_str = f"Strategic directives for {current_phase}:\n"
|
||||
for power, directive in strategic_directives.items():
|
||||
out_str += f"{power}: {directive}\n\n"
|
||||
out_str += f"------------------------------------------\n"
|
||||
with open(manifesto_path, "a") as f:
|
||||
f.write(f"=== {phase_data.name} ===\n{summary_text}\n\n")
|
||||
f.write(out_str)
|
||||
|
||||
# Check if we've exceeded the max year
|
||||
year_str = current_phase[1:5]
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
6
run.sh
6
run.sh
|
|
@ -2,6 +2,6 @@
|
|||
|
||||
# note the summaries aren't actually used so the model doesn't matter here
|
||||
python lm_game.py \
|
||||
--max_year 1910 \
|
||||
--summary_model "gpt-4o-mini" \
|
||||
--num_negotiation_rounds 3
|
||||
--max_year 1905 \
|
||||
--num_negotiation_rounds 1 \
|
||||
--models "gpt-4o-mini, gpt-4o-mini, gpt-4o-mini, gpt-4o-mini, gpt-4o-mini, gpt-4o-mini, gpt-4o-mini"
|
||||
Loading…
Add table
Add a link
Reference in a new issue