add optional planning phase

This commit is contained in:
Oam Patel 2025-02-27 02:10:48 +00:00
parent 98cb767d6c
commit 1f8ac5ae20
11 changed files with 222 additions and 77 deletions

View file

@ -59,6 +59,7 @@ class BaseModelClient:
power_name: str,
possible_orders: Dict[str, List[str]],
game_history: GameHistory,
include_plans: bool = True
) -> str:
context = load_prompt("context_prompt.txt")
@ -91,7 +92,7 @@ class BaseModelClient:
if fleet in units_info_set:
convoy_paths_possible.append((start_loc, fleets_req, end_loc))
conversation_text = game_history.get_game_history(power_name)
conversation_text = game_history.get_game_history(power_name, include_plans=include_plans)
if not conversation_text:
conversation_text = "\n(No game history yet)\n"
@ -325,6 +326,29 @@ class BaseModelClient:
fallback.append(holds[0] if holds else orders_list[0])
return fallback
def build_planning_prompt(
self,
game,
board_state,
power_name: str,
possible_orders: Dict[str, List[str]],
game_history: GameHistory,
game_phase: str,
) -> str:
instructions = load_prompt("planning_instructions.txt")
context = self.build_context_prompt(
game,
board_state,
power_name,
possible_orders,
game_history,
include_plans=False
)
return context + "\n\n" + instructions
def build_conversation_prompt(
self,
game,
@ -346,6 +370,29 @@ class BaseModelClient:
return context + "\n\n" + instructions
def get_planning_reply(
self,
game,
board_state,
power_name: str,
possible_orders: Dict[str, List[str]],
game_history: GameHistory,
game_phase: str,
active_powers: Optional[List[str]] = None,
) -> str:
prompt = self.build_planning_prompt(
game,
board_state,
power_name,
possible_orders,
game_history,
game_phase,
)
raw_response = self.generate_response(prompt)
return raw_response
def get_conversation_reply(
self,
game,
@ -356,6 +403,7 @@ class BaseModelClient:
game_phase: str,
active_powers: Optional[List[str]] = None,
) -> str:
prompt = self.build_conversation_prompt(
game,
board_state,
@ -368,6 +416,7 @@ class BaseModelClient:
raw_response = self.generate_response(prompt)
messages = []
import pdb; pdb.set_trace()
if raw_response:
try:
# Find the JSON block between double curly braces

View file

@ -20,6 +20,7 @@ class Message:
@dataclass
class Phase:
name: str # e.g. "SPRING 1901"
plans: Dict[str, str] = field(default_factory=dict)
messages: List[Message] = field(default_factory=list)
orders_by_power: Dict[str, List[str]] = field(
default_factory=lambda: defaultdict(list)
@ -28,6 +29,9 @@ class Phase:
default_factory=lambda: defaultdict(list)
)
def add_plan(self, power_name: str, plan: str):
self.plans[power_name] = plan
def add_message(self, sender: str, recipient: str, content: str):
self.messages.append(
Message(sender=sender, recipient=recipient, content=content)
@ -90,6 +94,11 @@ class GameHistory:
self.phases.append(new_phase)
return new_phase
def add_plan(self, phase_name: str, power_name: str, plan: str):
# get current phase
phase = self.add_phase(phase_name)
phase.add_plan(power_name, plan)
def add_message(self, phase_name: str, sender: str, recipient: str, content: str):
phase = self.add_phase(phase_name)
phase.add_message(sender, recipient, content)
@ -100,7 +109,13 @@ class GameHistory:
phase = self.add_phase(phase_name)
phase.add_orders(power, orders, results)
def get_game_history(self, power_name: str, num_prev_phases: int = 5) -> str:
def get_strategic_directives(self):
# returns for last phase only if exists
if not self.phases:
return {}
return self.phases[-1].plans
def get_game_history(self, power_name: str, include_plans: bool = True, num_prev_phases: int = 5) -> str:
if not self.phases:
return ""
@ -146,4 +161,10 @@ class GameHistory:
game_history_str += "-" * 50 + "\n" # Add separator between phases
# NOTE: only reports plan for the last phase (otherwise too much clutter)
if include_plans and phases_to_report and (power_name in phases_to_report[-1].plans):
game_history_str += f"\n{power_name} STRATEGIC DIRECTIVE:\n"
game_history_str += "Here is a high-level directive you have planned out previously for this phase.\n"
game_history_str += phases_to_report[-1].plans[power_name] + "\n"
return game_history_str

View file

@ -22,9 +22,6 @@ def conduct_negotiations(game, game_history, model_error_stats, max_rounds=3):
"""
logger.info("Starting negotiation phase.")
# Conversation messages are kept in a local list ONLY to build conversation_so_far text.
conversation_messages = []
active_powers = [
p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()
]
@ -32,7 +29,7 @@ def conduct_negotiations(game, game_history, model_error_stats, max_rounds=3):
# We do up to 'max_rounds' single-message turns for each power
for round_index in range(max_rounds):
with concurrent.futures.ThreadPoolExecutor(
max_workers=len(active_powers)
max_workers=1
) as executor:
futures = {}
for power_name in active_powers:
@ -60,7 +57,6 @@ def conduct_negotiations(game, game_history, model_error_stats, max_rounds=3):
for future in concurrent.futures.as_completed(futures):
power_name = futures[future]
try:
messages = future.result()
if messages:
for message in messages:
@ -78,12 +74,11 @@ def conduct_negotiations(game, game_history, model_error_stats, max_rounds=3):
message["recipient"],
message["content"],
)
conversation_messages.append(message)
else:
logger.debug(f"No valid messages returned for {power_name}.")
model_error_stats[power_name]["conversation_errors"] += 1
except Exception as exc:
logger.error(f"LLM request failed for {power_name}: {exc}")
# except Exception as exc:
# logger.error(f"LLM request failed for {power_name}: {exc}")
logger.info("Negotiation phase complete.")
return conversation_messages
return game_history

67
ai_diplomacy/planning.py Normal file
View file

@ -0,0 +1,67 @@
from dotenv import load_dotenv
import logging
import concurrent.futures
from diplomacy.engine.message import Message, GLOBAL
from .clients import load_model_client
from .utils import gather_possible_orders
logger = logging.getLogger("utils")
logger.setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)
load_dotenv()
def planning_phase(game, game_history, model_error_stats, max_rounds=3):
"""
Lets each power form a strategic directive for the upcoming phase.
"""
active_powers = [
p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()
]
with concurrent.futures.ThreadPoolExecutor(
max_workers=len(active_powers)
) as executor:
futures = {}
for power_name in active_powers:
model_id = game.power_model_map.get(power_name, "o3-mini")
client = load_model_client(model_id)
possible_orders = gather_possible_orders(game, power_name)
if not possible_orders:
logger.info(f"No orderable locations for {power_name}; skipping.")
continue
board_state = game.get_state()
future = executor.submit(
client.get_planning_reply,
game,
board_state,
power_name,
possible_orders,
game_history,
game.current_short_phase,
active_powers,
)
futures[future] = power_name
logger.debug(f"Submitted get_planning_reply task for {power_name}.")
logger.info("Waiting for planning replies...")
for future in concurrent.futures.as_completed(futures):
power_name = futures[future]
try:
reply = future.result()
logger.info(f"Received planning reply from {power_name}.")
if reply:
game_history.add_plan(
game.current_short_phase, power_name, reply
)
except Exception as e:
logger.error(f"Error in planning reply for {power_name}: {e}")
model_error_stats[power_name] += 1
logger.info("Planning phase complete.")
return game_history

View file

@ -39,6 +39,6 @@ PARSABLE OUTPUT:
"content": "Let's form a secret alliance against Germany."
}}
Reason step-by-step through your diplomatic position and the past phase summaries. Make sure you give thorough and clear reasoning. Send messages that tactically advance your position in the future. Be concrete in your messages. Your goal is to acquire supply centers and win the game. You can lie to achieve this goal. If there is no value in sending messages, then don't send any messages. Just respond with '.'.
Reason step-by-step through your diplomatic position, the state of the game, and your past strategic directive. Make sure you give thorough and clear reasoning. Send messages that tactically advance your position in the future. Be concrete in your messages. Your goal is to acquire supply centers and win the game. You can lie to achieve this goal. If there is no value in sending messages, then don't send any messages. Just respond with '.'.
Remember you are in messaging mode right now. Do not give any orders. Each message you send must be in a separate bracketed JSON section. You may send multiple messages. Do not send them all as one JSON object. Do not repeat any JSON object in your output twice.

View file

@ -0,0 +1,14 @@
**CURRENT INSTRUCTIONS**
You are now to prepare a high-level strategic directive for this phase. Remember that your goal is to win via capturing supply centers. There are opportunity costs in this game.
You should reason through
1. What would be ideal locations for you to capture during this phase (or in future phases),
2. What needs to happen for those orders to be successful (e.g. not bounce),
3. Are there strategic negotiations you could engage in with other powers to make those bounces less likely?,
4. Will other powers be targeting any of your supply centers in this phase (or in future phases)?,
5. Are there strategic negotiations you could engage in to reduce the chance they target your supply centers?
Develop a detailed directive that reasons through the supply centers you'd like to take and how you'll take them (including possible orders and possible alliances, concessions, lies, etc). Note that you can make deals, threats, or alliances with other powers. You can lie.
Note that the plan you outline below will be presented to yourself when you are asked to generate specific negotiation messages and specific orders. Be detailed, strategic, smart, and forward-thinking. Be wary of others trying to deceive you. Be specific as your response should help you produce messages for negotiation and orders later. Output a directive that is roughly a paragraph long.

View file

@ -5,8 +5,9 @@ You will be given:
• The current phase (e.g. S1901M).
• Details about the map.
• Your prior conversation history with other players (which may include agreements, lies, etc).
* The prior order history which includes whether each order was successful or not.
• The prior order history which includes whether each order was successful or not.
• A strategic plan that you have made if you are in the negotiations or orders phase.
• Your units and the possible orders you may make. **Always refer to these possible_orders.**
• A list of enemy units and centers.
Remember that while your private chain-of-thought can consider your in-depth reasoning about possible outcomes, **only** the “PARSABLE OUTPUT” (your final orders or messages) will be used by the game engine.
For the negotiations and orders phase, remember that while your private chain-of-thought can consider your in-depth reasoning about possible outcomes, **only** the “PARSABLE OUTPUT” (your final orders or messages) will be used by the game engine.

View file

@ -17,9 +17,9 @@ def assign_models_to_powers():
return {
"FRANCE": "o3-mini",
"GERMANY": "claude-3-5-sonnet-20241022",
"GERMANY": "claude-3-5-sonnet-latest",
"ENGLAND": "gemini-2.0-flash",
"RUSSIA": "gemini-2.0-flash-lite-preview-02-05",
"RUSSIA": "claude-3.7-sonnet-latest",
"ITALY": "gpt-4o",
"AUSTRIA": "gpt-4o-mini",
"TURKEY": "claude-3-5-haiku-20241022",

View file

@ -20,6 +20,7 @@ from ai_diplomacy.utils import (
assign_models_to_powers,
)
from ai_diplomacy.negotiations import conduct_negotiations
from ai_diplomacy.planning import planning_phase
from ai_diplomacy.game_history import GameHistory
dotenv.load_dotenv()
@ -32,14 +33,6 @@ logging.basicConfig(
)
def my_summary_callback(system_prompt, user_prompt, model_name):
# Route to the desired model specified by the command-line argument
client = load_model_client(model_name)
combined_prompt = f"{system_prompt}\n\n{user_prompt}"
# Pseudo-code for generating a response:
return client.generate_response(combined_prompt)
def parse_arguments():
parser = argparse.ArgumentParser(
description="Run a Diplomacy game simulation with configurable parameters."
@ -50,12 +43,6 @@ def parse_arguments():
default=1901,
help="Maximum year to simulate. The game will stop once this year is reached.",
)
parser.add_argument(
"--summary_model",
type=str,
default="o3-mini",
help="Model name to use for generating phase summaries.",
)
parser.add_argument(
"--num_negotiation_rounds",
type=int,
@ -77,13 +64,17 @@ def parse_arguments():
"The order is: AUSTRIA, ENGLAND, FRANCE, GERMANY, ITALY, RUSSIA, TURKEY."
),
)
parser.add_argument(
"--planning_phase",
action="store_true",
help="Enable the planning phase for each power to set strategic directives.",
)
return parser.parse_args()
def main():
args = parse_arguments()
max_year = args.max_year
summary_model = args.summary_model
logger.info(
"Starting a new Diplomacy game for testing with multiple LLMs, now concurrent!"
@ -154,15 +145,21 @@ def main():
# If it's a movement phase (e.g. ends with "M"), conduct negotiations
if game.current_short_phase.endswith("M"):
if args.planning_phase:
logger.info("Starting planning phase block...")
game_history = planning_phase(
game,
game_history,
model_error_stats,
)
logger.info("Starting negotiation phase block...")
conversation_messages = conduct_negotiations(
game_history = conduct_negotiations(
game,
game_history,
model_error_stats,
max_rounds=args.num_negotiation_rounds,
)
else:
conversation_messages = []
# Gather orders from each power concurrently
active_powers = [
@ -213,12 +210,7 @@ def main():
logger.error(f"LLM request failed for {p_name}: {exc}")
logger.info("Processing orders...\n")
# Pass the summary model to the callback via a lambda function
phase_data = game.process(
phase_summary_callback=lambda sys, usr: my_summary_callback(
sys, usr, summary_model
)
)
game.process()
# Add orders to game history
for power_name in game.order_history[current_phase]:
orders = game.order_history[current_phase][power_name]
@ -241,16 +233,15 @@ def main():
)
logger.info("Phase complete.\n")
# Retrieve and log the summary of the phase
summary_text = phase_data.summary or "(No summary found.)"
border = "=" * 80
logger.info(
f"{border}\nPHASE SUMMARY for {phase_data.name}:\n{summary_text}\n{border}"
)
# Append the summary to the manifesto file
# Append the strategic directives to the manifesto file
strategic_directives = game_history.get_strategic_directives()
if strategic_directives:
out_str = f"Strategic directives for {current_phase}:\n"
for power, directive in strategic_directives.items():
out_str += f"{power}: {directive}\n\n"
out_str += f"------------------------------------------\n"
with open(manifesto_path, "a") as f:
f.write(f"=== {phase_data.name} ===\n{summary_text}\n\n")
f.write(out_str)
# Check if we've exceeded the max year
year_str = current_phase[1:5]

File diff suppressed because one or more lines are too long

6
run.sh
View file

@ -2,6 +2,6 @@
# note the summaries aren't actually used so the model doesn't matter here
python lm_game.py \
--max_year 1910 \
--summary_model "gpt-4o-mini" \
--num_negotiation_rounds 3
--max_year 1905 \
--num_negotiation_rounds 1 \
--models "gpt-4o-mini, gpt-4o-mini, gpt-4o-mini, gpt-4o-mini, gpt-4o-mini, gpt-4o-mini, gpt-4o-mini"