no press + context includes orders and results

This commit is contained in:
Oam Patel 2025-02-18 19:29:04 +00:00
parent 9443384c89
commit e3db81f00e
8 changed files with 206 additions and 132 deletions

View file

@ -16,7 +16,7 @@ from openai import OpenAI as DeepSeekOpenAI
from diplomacy.engine.message import GLOBAL from diplomacy.engine.message import GLOBAL
from .conversation_history import ConversationHistory from .game_history import GameHistory
# set logger back to just info # set logger back to just info
logger = logging.getLogger("client") logger = logging.getLogger("client")
@ -55,8 +55,7 @@ class BaseModelClient:
board_state, board_state,
power_name: str, power_name: str,
possible_orders: Dict[str, List[str]], possible_orders: Dict[str, List[str]],
conversation_history: ConversationHistory, game_history: GameHistory,
phase_summaries: Optional[Dict[str, str]] = None,
) -> str: ) -> str:
context = load_prompt("context_prompt.txt") context = load_prompt("context_prompt.txt")
@ -89,17 +88,9 @@ class BaseModelClient:
if fleet in units_info_set: if fleet in units_info_set:
convoy_paths_possible.append((start_loc, fleets_req, end_loc)) convoy_paths_possible.append((start_loc, fleets_req, end_loc))
# 1) Prepare a block of text for the phase_summaries conversation_text = game_history.get_game_history(power_name)
if phase_summaries:
historical_summaries = "\nPAST PHASE SUMMARIES:\n"
for phase_key, summary_txt in phase_summaries.items():
historical_summaries += f"\nPHASE {phase_key}:\n{summary_txt}\n"
else:
historical_summaries = "\n(No historical summaries yet)\n"
conversation_text = conversation_history.get_conversation_history(power_name)
if not conversation_text: if not conversation_text:
conversation_text = "\n(No conversation history yet)\n" conversation_text = "\n(No game history yet)\n"
# Load in current context values # Load in current context values
context = context.format( context = context.format(
@ -110,8 +101,7 @@ class BaseModelClient:
map_as_adjacency_list=game.map.loc_abut, map_as_adjacency_list=game.map.loc_abut,
possible_coasts=game.map.loc_coasts, possible_coasts=game.map.loc_coasts,
game_map_scs=game.map.scs, game_map_scs=game.map.scs,
historical_summaries=historical_summaries, game_history=conversation_text,
conversation_history=conversation_text,
enemy_units=enemy_units, enemy_units=enemy_units,
enemy_centers=enemy_centers, enemy_centers=enemy_centers,
units_info=units_info, units_info=units_info,
@ -128,8 +118,7 @@ class BaseModelClient:
board_state, board_state,
power_name: str, power_name: str,
possible_orders: Dict[str, List[str]], possible_orders: Dict[str, List[str]],
conversation_history: ConversationHistory, game_history: GameHistory,
phase_summaries: Optional[Dict[str, str]] = None,
) -> str: ) -> str:
""" """
Unified prompt approach: incorporate conversation and 'PARSABLE OUTPUT' requirements. Unified prompt approach: incorporate conversation and 'PARSABLE OUTPUT' requirements.
@ -144,8 +133,7 @@ class BaseModelClient:
board_state, board_state,
power_name, power_name,
possible_orders, possible_orders,
conversation_history, game_history,
phase_summaries,
) )
return context + "\n\n" + instructions return context + "\n\n" + instructions
@ -157,7 +145,6 @@ class BaseModelClient:
power_name: str, power_name: str,
possible_orders: Dict[str, List[str]], possible_orders: Dict[str, List[str]],
conversation_text: str, conversation_text: str,
phase_summaries: Optional[Dict[str, str]] = None,
model_error_stats=None, # New optional param model_error_stats=None, # New optional param
) -> List[str]: ) -> List[str]:
""" """
@ -171,7 +158,6 @@ class BaseModelClient:
power_name, power_name,
possible_orders, possible_orders,
conversation_text, conversation_text,
phase_summaries,
) )
raw_response = "" raw_response = ""
@ -184,6 +170,8 @@ class BaseModelClient:
# Attempt to parse the final "orders" from the LLM # Attempt to parse the final "orders" from the LLM
move_list = self._extract_moves(raw_response, power_name) move_list = self._extract_moves(raw_response, power_name)
print(f"prompt {prompt}")
print(f"response {raw_response}")
if not move_list: if not move_list:
logger.warning( logger.warning(
f"[{self.model_name}] Could not extract moves for {power_name}. Using fallback." f"[{self.model_name}] Could not extract moves for {power_name}. Using fallback."
@ -341,9 +329,8 @@ class BaseModelClient:
board_state, board_state,
power_name: str, power_name: str,
possible_orders: Dict[str, List[str]], possible_orders: Dict[str, List[str]],
conversation_history: ConversationHistory, game_history: GameHistory,
game_phase: str, game_phase: str,
phase_summaries: Optional[Dict[str, str]] = None,
) -> str: ) -> str:
instructions = load_prompt("conversation_instructions.txt") instructions = load_prompt("conversation_instructions.txt")
@ -352,8 +339,7 @@ class BaseModelClient:
board_state, board_state,
power_name, power_name,
possible_orders, possible_orders,
conversation_history, game_history,
phase_summaries,
) )
return context + "\n\n" + instructions return context + "\n\n" + instructions
@ -364,9 +350,8 @@ class BaseModelClient:
board_state, board_state,
power_name: str, power_name: str,
possible_orders: Dict[str, List[str]], possible_orders: Dict[str, List[str]],
conversation_history: ConversationHistory, game_history: GameHistory,
game_phase: str, game_phase: str,
phase_summaries: Optional[Dict[str, str]] = None,
active_powers: Optional[List[str]] = None, active_powers: Optional[List[str]] = None,
) -> str: ) -> str:
prompt = self.build_conversation_prompt( prompt = self.build_conversation_prompt(
@ -374,9 +359,8 @@ class BaseModelClient:
board_state, board_state,
power_name, power_name,
possible_orders, possible_orders,
conversation_history, game_history,
game_phase, game_phase,
phase_summaries,
) )
raw_response = self.generate_response(prompt) raw_response = self.generate_response(prompt)

View file

@ -1,62 +0,0 @@
from dotenv import load_dotenv
import logging
from collections import defaultdict
logger = logging.getLogger("utils")
logger.setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)
load_dotenv()
class ConversationHistory:
def __init__(self):
self.phases = []
self.history_by_power = defaultdict(
lambda: defaultdict(lambda: defaultdict(str))
)
self.global_history = defaultdict(lambda: defaultdict(str))
def add_message(self, year_phase, power_name, message):
if year_phase not in self.phases:
self.phases.append(year_phase)
if message["recipient"] == "GLOBAL":
self.global_history["GLOBAL"][year_phase] += (
f" {power_name}: {message['content']}\n"
)
self.history_by_power[power_name][year_phase][message["recipient"]] += (
f" {power_name}: {message['content']}\n"
)
self.history_by_power[message["recipient"]][year_phase][power_name] += (
f" {power_name}: {message['content']}\n"
)
def add_messages(self, year_phase, messages):
if year_phase in self.data:
self.data[year_phase]
else:
self.phases.append(year_phase)
def get_conversation_history(self, power_name, num_prev_phases=5):
phases_to_report = self.phases[-num_prev_phases:]
conversation_history_str = ""
if self.global_history["GLOBAL"]:
conversation_history_str += "GLOBAL:\n"
for phase in phases_to_report:
if phase in self.global_history["GLOBAL"]:
conversation_history_str += f"\n{phase}:\n\n"
conversation_history_str += self.global_history["GLOBAL"][phase]
conversation_history_str += "\n"
if self.history_by_power[power_name]:
conversation_history_str += "PRIVATE:\n"
for phase in phases_to_report:
if phase in self.history_by_power[power_name]:
conversation_history_str += f"\n{phase}:\n"
for power in self.history_by_power[power_name][phase].keys():
conversation_history_str += f"\n {power}:\n\n"
conversation_history_str += self.history_by_power[power_name][
phase
][power]
return conversation_history_str

View file

@ -0,0 +1,157 @@
from dotenv import load_dotenv
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List, Optional
logger = logging.getLogger("utils")
logger.setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)
load_dotenv()
@dataclass
class Message:
sender: str
recipient: str
content: str
@dataclass
class Phase:
name: str # e.g. "SPRING 1901"
messages: List[Message] = field(default_factory=list)
orders_by_power: Dict[str, List[str]] = field(
default_factory=lambda: defaultdict(list)
)
results_by_power: Dict[str, List[List[str]]] = field(
default_factory=lambda: defaultdict(list)
)
def add_message(self, sender: str, recipient: str, content: str):
self.messages.append(
Message(sender=sender, recipient=recipient, content=content)
)
def add_orders(self, power: str, orders: List[str], results: List[List[str]]):
self.orders_by_power[power].extend(orders)
# Make sure results has the same length as orders, if not, pad with empty lists
if len(results) < len(orders):
results.extend([[] for _ in range(len(orders) - len(results))])
self.results_by_power[power].extend(results)
def get_global_messages(self) -> str:
result = ""
for msg in self.messages:
if msg.recipient == "GLOBAL":
result += f" {msg.sender}: {msg.content}\n"
return result
def get_private_messages(self, power: str) -> Dict[str, str]:
conversations = defaultdict(str)
for msg in self.messages:
if msg.sender == power and msg.recipient != "GLOBAL":
conversations[msg.recipient] += f" {power}: {msg.content}\n"
elif msg.recipient == power:
conversations[msg.sender] += f" {msg.sender}: {msg.content}\n"
return conversations
def get_all_orders_formatted(self) -> str:
if not self.orders_by_power:
return ""
result = f"\nOrders for {self.name}:\n"
for power, orders in self.orders_by_power.items():
result += f"{power}:\n"
results = self.results_by_power.get(power, [])
for i, order in enumerate(orders):
if i < len(results) and results[i]:
# Join multiple results with commas
result_str = f" ({', '.join(results[i])})"
else:
result_str = " (successful)"
result += f" {order}{result_str}\n"
result += "\n"
return result
@dataclass
class GameHistory:
phases: List[Phase] = field(default_factory=list)
def add_phase(self, phase_name: str) -> Phase:
# Check if phase already exists
for phase in self.phases:
if phase.name == phase_name:
return phase
# Create new phase
new_phase = Phase(name=phase_name)
self.phases.append(new_phase)
return new_phase
def add_message(self, phase_name: str, sender: str, recipient: str, content: str):
phase = self.add_phase(phase_name)
phase.add_message(sender, recipient, content)
def add_orders(
self, phase_name: str, power: str, orders: List[str], results: List[List[str]]
):
phase = self.add_phase(phase_name)
phase.add_orders(power, orders, results)
def get_game_history(self, power_name: str, num_prev_phases: int = 5) -> str:
if not self.phases:
return ""
phases_to_report = self.phases[-num_prev_phases:]
game_history_str = ""
# Add GLOBAL section
has_global = any(phase.get_global_messages() for phase in phases_to_report)
if has_global:
game_history_str += "GLOBAL:\n"
for phase in phases_to_report:
global_msgs = phase.get_global_messages()
if global_msgs:
game_history_str += f"\n{phase.name}:\n\n"
game_history_str += global_msgs
game_history_str += "\n"
# Add PRIVATE section
has_private = any(
phase.get_private_messages(power_name) for phase in phases_to_report
)
if has_private:
game_history_str += "PRIVATE:\n"
for phase in phases_to_report:
private_msgs = phase.get_private_messages(power_name)
if private_msgs:
game_history_str += f"\n{phase.name}:\n"
for other_power, messages in private_msgs.items():
game_history_str += f"\n {other_power}:\n\n"
game_history_str += messages
# Add ORDERS section - including all powers' orders with results
if any(phase.orders_by_power for phase in phases_to_report):
game_history_str += "\nORDERS:\n"
for phase in phases_to_report:
if phase.orders_by_power:
game_history_str += f"\n{phase.name}:\n"
for power, orders in phase.orders_by_power.items():
game_history_str += f"{power}:\n"
results = phase.results_by_power.get(power, [])
for i, order in enumerate(orders):
if (
i < len(results)
and results[i]
and not all(r == "" for r in results[i])
):
# Join multiple results with commas
result_str = f" ({', '.join(results[i])})"
else:
result_str = " (successful)"
game_history_str += f" {order}{result_str}\n"
game_history_str += "\n"
return game_history_str

View file

@ -13,7 +13,7 @@ logging.basicConfig(level=logging.INFO)
load_dotenv() load_dotenv()
def conduct_negotiations(game, conversation_history, model_error_stats, max_rounds=3): def conduct_negotiations(game, game_history, model_error_stats, max_rounds=3):
""" """
Conducts a round-robin conversation among all non-eliminated powers. Conducts a round-robin conversation among all non-eliminated powers.
Each power can send up to 'max_rounds' messages, choosing between private Each power can send up to 'max_rounds' messages, choosing between private
@ -23,6 +23,7 @@ def conduct_negotiations(game, conversation_history, model_error_stats, max_roun
# Conversation messages are kept in a local list ONLY to build conversation_so_far text. # Conversation messages are kept in a local list ONLY to build conversation_so_far text.
conversation_messages = [] conversation_messages = []
return conversation_messages
active_powers = [ active_powers = [
p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated() p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()
@ -46,9 +47,8 @@ def conduct_negotiations(game, conversation_history, model_error_stats, max_roun
board_state=board_state, board_state=board_state,
power_name=power_name, power_name=power_name,
possible_orders=possible_orders, possible_orders=possible_orders,
conversation_history=conversation_history, game_history=game_history,
game_phase=game.current_short_phase, game_phase=game.current_short_phase,
phase_summaries=game.phase_summaries,
active_powers=active_powers, active_powers=active_powers,
) )
@ -62,7 +62,7 @@ def conduct_negotiations(game, conversation_history, model_error_stats, max_roun
message=message["content"], message=message["content"],
) )
game.add_message(diplo_message) game.add_message(diplo_message)
conversation_history.add_message( game_history.add_message(
game.current_short_phase, power_name, message game.current_short_phase, power_name, message
) )
conversation_messages.append(message) conversation_messages.append(message)

View file

@ -20,13 +20,9 @@ Possible coasts at each location:
All supply centers on the map: All supply centers on the map:
{game_map_scs} {game_map_scs}
**HISTORICAL SUMMARIES** **GAME HISTORY**
{historical_summaries} {game_history}
**CONVERSATION HISTORY**
{conversation_history}
**CURRENT CONTEXT** **CURRENT CONTEXT**

View file

@ -1,30 +1,12 @@
You are playing a game of Diplomacy over text. The map is the standard Diplomacy map. You are playing a game of Diplomacy over text. The map is the standard Diplomacy map. Your goal is to win the game by capturing supply centers, growing your army, and taking over the map. Be aggressive.
You will be given: You will be given:
• Which power you are controlling. • Which power you are controlling.
• The current phase (e.g. S1901M). • The current phase (e.g. S1901M).
• Details about the map. • Details about the map.
• Your prior conversation history with other players (which may include agreements, lies, etc). • Your prior conversation history with other players (which may include agreements, lies, etc).
• Your units and the possible orders for each. Always refer to these possible_orders. • Your units and the possible orders you may make. Always refer to these possible_orders.
Summaries of past phases (including bounces, voids, and the reasons). Results of past phases (including bounces and voids)
• A list of enemy units and centers. • A list of enemy units and centers.
• A “convoy_paths_possible” listing, if relevant, describing possible convoy routes (e.g. [("A NAP", {F ION, F TYS}, "TUN"), ...]).
**Your tasks**: Remember that while your private chain-of-thought can consider your in-depth reasoning about possible outcomes, **only** the “PARSABLE OUTPUT” (your final orders or messages) will be used by the game engine.
1) Reflect on your strategic goals and the current board situation.
2) **Strategize** to increase your supply centers, defend your existing centers, and expand influence.
3) **Coordinate** your orders so they do not produce internal contradictions. For example, do not support a move that you or your ally are not making.
4) Check if any of your previous moves were blocked or voided. Learn from that:
- If you bounced due to equal force, consider using support or picking a different target.
- If you had an invalid adjacency or a mismatch in support, fix it this turn.
5) Evaluate if you can use a convoy; consult “convoy_paths_possible” to see if a valid route exists.
6) Depending on what phase of the game we are in, either send messages to other players or propose a set of final orders. Always refer to the possible_orders list before committing to an order.
7) **Anticipate** that other powers may lie or might try to bounce you. If you suspect a bounce, consider using support or a safer move.
8) **Return a valid, consistent set of final orders or messages** in the required JSON format.
Remember that while your private chain-of-thought can consider your in-depth reasoning about possible outcomes, **only** the “PARSABLE OUTPUT” (your final orders or messages) will be used by the game engine.
- If you use a support order, ensure you (or an ally) actually have a matching move that it supports. Reference the map if necessary to make sure that the move is possible.
- If you do a convoy, ensure the fleets and adjacency match the “convoy_paths_possible” data.
- Attempt to avoid guaranteed bounces unless you see a diplomatic reason to do so.
- If you suspect an enemy might also move to your target, consider using support or picking a safer approach.

View file

@ -45,8 +45,7 @@ def get_valid_orders(
board_state, board_state,
power_name, power_name,
possible_orders, possible_orders,
conversation_history, game_history,
phase_summaries,
model_error_stats, model_error_stats,
): ):
""" """
@ -61,8 +60,7 @@ def get_valid_orders(
board_state=board_state, board_state=board_state,
power_name=power_name, power_name=power_name,
possible_orders=possible_orders, possible_orders=possible_orders,
conversation_text=conversation_history, conversation_text=game_history,
phase_summaries=phase_summaries,
model_error_stats=model_error_stats, model_error_stats=model_error_stats,
) )

View file

@ -20,7 +20,7 @@ from ai_diplomacy.utils import (
assign_models_to_powers, assign_models_to_powers,
) )
from ai_diplomacy.negotiations import conduct_negotiations from ai_diplomacy.negotiations import conduct_negotiations
from ai_diplomacy.conversation_history import ConversationHistory from ai_diplomacy.game_history import GameHistory
dotenv.load_dotenv() dotenv.load_dotenv()
@ -90,7 +90,7 @@ def main():
# Create a fresh Diplomacy game # Create a fresh Diplomacy game
game = Game() game = Game()
conversation_history = ConversationHistory() game_history = GameHistory()
# Ensure game has phase_summaries attribute # Ensure game has phase_summaries attribute
if not hasattr(game, "phase_summaries"): if not hasattr(game, "phase_summaries"):
@ -150,7 +150,7 @@ def main():
if game.current_short_phase.endswith("M"): if game.current_short_phase.endswith("M"):
logger.info("Starting negotiation phase block...") logger.info("Starting negotiation phase block...")
conversation_messages = conduct_negotiations( conversation_messages = conduct_negotiations(
game, conversation_history, model_error_stats, max_rounds=10 game, game_history, model_error_stats, max_rounds=10
) )
else: else:
conversation_messages = [] conversation_messages = []
@ -180,8 +180,7 @@ def main():
board_state, board_state,
power_name, power_name,
possible_orders, possible_orders,
conversation_history, game_history,
game.phase_summaries,
model_error_stats, model_error_stats,
) )
futures[future] = power_name futures[future] = power_name
@ -209,6 +208,26 @@ def main():
sys, usr, summary_model sys, usr, summary_model
) )
) )
# Add orders to game history
for power_name in game.order_history[current_phase]:
orders = game.order_history[current_phase][power_name]
results = []
for order in orders:
# Example move: "A PAR H" -> unit="A PAR", order_part="H"
tokens = order.split(" ", 2)
if len(tokens) < 3:
continue
unit = " ".join(tokens[:2]) # e.g. "A PAR"
order_part = tokens[2] # e.g. "H" or "S A MAR"
results.append(
[str(x) for x in game.result_history[current_phase][unit]]
)
game_history.add_orders(
current_phase,
power_name,
game.order_history[current_phase][power_name],
results,
)
logger.info("Phase complete.\n") logger.info("Phase complete.\n")
# Retrieve and log the summary of the phase # Retrieve and log the summary of the phase