From e3db81f00e8c7f05ce4d6db4eaa4e6aa886288a1 Mon Sep 17 00:00:00 2001 From: Oam Patel Date: Tue, 18 Feb 2025 19:29:04 +0000 Subject: [PATCH] no press + context includes orders and results --- ai_diplomacy/clients.py | 42 ++----- ai_diplomacy/conversation_history.py | 62 ---------- ai_diplomacy/game_history.py | 157 ++++++++++++++++++++++++ ai_diplomacy/negotiations.py | 8 +- ai_diplomacy/prompts/context_prompt.txt | 8 +- ai_diplomacy/prompts/system_prompt.txt | 26 +--- ai_diplomacy/utils.py | 6 +- lm_game.py | 29 ++++- 8 files changed, 206 insertions(+), 132 deletions(-) delete mode 100644 ai_diplomacy/conversation_history.py create mode 100644 ai_diplomacy/game_history.py diff --git a/ai_diplomacy/clients.py b/ai_diplomacy/clients.py index 71937cf..34c8065 100644 --- a/ai_diplomacy/clients.py +++ b/ai_diplomacy/clients.py @@ -16,7 +16,7 @@ from openai import OpenAI as DeepSeekOpenAI from diplomacy.engine.message import GLOBAL -from .conversation_history import ConversationHistory +from .game_history import GameHistory # set logger back to just info logger = logging.getLogger("client") @@ -55,8 +55,7 @@ class BaseModelClient: board_state, power_name: str, possible_orders: Dict[str, List[str]], - conversation_history: ConversationHistory, - phase_summaries: Optional[Dict[str, str]] = None, + game_history: GameHistory, ) -> str: context = load_prompt("context_prompt.txt") @@ -89,17 +88,9 @@ class BaseModelClient: if fleet in units_info_set: convoy_paths_possible.append((start_loc, fleets_req, end_loc)) - # 1) Prepare a block of text for the phase_summaries - if phase_summaries: - historical_summaries = "\nPAST PHASE SUMMARIES:\n" - for phase_key, summary_txt in phase_summaries.items(): - historical_summaries += f"\nPHASE {phase_key}:\n{summary_txt}\n" - else: - historical_summaries = "\n(No historical summaries yet)\n" - - conversation_text = conversation_history.get_conversation_history(power_name) + conversation_text = game_history.get_game_history(power_name) if not conversation_text: - conversation_text = "\n(No conversation history yet)\n" + conversation_text = "\n(No game history yet)\n" # Load in current context values context = context.format( @@ -110,8 +101,7 @@ class BaseModelClient: map_as_adjacency_list=game.map.loc_abut, possible_coasts=game.map.loc_coasts, game_map_scs=game.map.scs, - historical_summaries=historical_summaries, - conversation_history=conversation_text, + game_history=conversation_text, enemy_units=enemy_units, enemy_centers=enemy_centers, units_info=units_info, @@ -128,8 +118,7 @@ class BaseModelClient: board_state, power_name: str, possible_orders: Dict[str, List[str]], - conversation_history: ConversationHistory, - phase_summaries: Optional[Dict[str, str]] = None, + game_history: GameHistory, ) -> str: """ Unified prompt approach: incorporate conversation and 'PARSABLE OUTPUT' requirements. @@ -144,8 +133,7 @@ class BaseModelClient: board_state, power_name, possible_orders, - conversation_history, - phase_summaries, + game_history, ) return context + "\n\n" + instructions @@ -157,7 +145,6 @@ class BaseModelClient: power_name: str, possible_orders: Dict[str, List[str]], conversation_text: str, - phase_summaries: Optional[Dict[str, str]] = None, model_error_stats=None, # New optional param ) -> List[str]: """ @@ -171,7 +158,6 @@ class BaseModelClient: power_name, possible_orders, conversation_text, - phase_summaries, ) raw_response = "" @@ -184,6 +170,8 @@ class BaseModelClient: # Attempt to parse the final "orders" from the LLM move_list = self._extract_moves(raw_response, power_name) + print(f"prompt {prompt}") + print(f"response {raw_response}") if not move_list: logger.warning( f"[{self.model_name}] Could not extract moves for {power_name}. Using fallback." @@ -341,9 +329,8 @@ class BaseModelClient: board_state, power_name: str, possible_orders: Dict[str, List[str]], - conversation_history: ConversationHistory, + game_history: GameHistory, game_phase: str, - phase_summaries: Optional[Dict[str, str]] = None, ) -> str: instructions = load_prompt("conversation_instructions.txt") @@ -352,8 +339,7 @@ class BaseModelClient: board_state, power_name, possible_orders, - conversation_history, - phase_summaries, + game_history, ) return context + "\n\n" + instructions @@ -364,9 +350,8 @@ class BaseModelClient: board_state, power_name: str, possible_orders: Dict[str, List[str]], - conversation_history: ConversationHistory, + game_history: GameHistory, game_phase: str, - phase_summaries: Optional[Dict[str, str]] = None, active_powers: Optional[List[str]] = None, ) -> str: prompt = self.build_conversation_prompt( @@ -374,9 +359,8 @@ class BaseModelClient: board_state, power_name, possible_orders, - conversation_history, + game_history, game_phase, - phase_summaries, ) raw_response = self.generate_response(prompt) diff --git a/ai_diplomacy/conversation_history.py b/ai_diplomacy/conversation_history.py deleted file mode 100644 index 5f90439..0000000 --- a/ai_diplomacy/conversation_history.py +++ /dev/null @@ -1,62 +0,0 @@ -from dotenv import load_dotenv -import logging -from collections import defaultdict - -logger = logging.getLogger("utils") -logger.setLevel(logging.INFO) -logging.basicConfig(level=logging.INFO) - -load_dotenv() - - -class ConversationHistory: - def __init__(self): - self.phases = [] - self.history_by_power = defaultdict( - lambda: defaultdict(lambda: defaultdict(str)) - ) - self.global_history = defaultdict(lambda: defaultdict(str)) - - def add_message(self, year_phase, power_name, message): - if year_phase not in self.phases: - self.phases.append(year_phase) - - if message["recipient"] == "GLOBAL": - self.global_history["GLOBAL"][year_phase] += ( - f" {power_name}: {message['content']}\n" - ) - self.history_by_power[power_name][year_phase][message["recipient"]] += ( - f" {power_name}: {message['content']}\n" - ) - self.history_by_power[message["recipient"]][year_phase][power_name] += ( - f" {power_name}: {message['content']}\n" - ) - - def add_messages(self, year_phase, messages): - if year_phase in self.data: - self.data[year_phase] - else: - self.phases.append(year_phase) - - def get_conversation_history(self, power_name, num_prev_phases=5): - phases_to_report = self.phases[-num_prev_phases:] - conversation_history_str = "" - if self.global_history["GLOBAL"]: - conversation_history_str += "GLOBAL:\n" - for phase in phases_to_report: - if phase in self.global_history["GLOBAL"]: - conversation_history_str += f"\n{phase}:\n\n" - conversation_history_str += self.global_history["GLOBAL"][phase] - conversation_history_str += "\n" - if self.history_by_power[power_name]: - conversation_history_str += "PRIVATE:\n" - for phase in phases_to_report: - if phase in self.history_by_power[power_name]: - conversation_history_str += f"\n{phase}:\n" - for power in self.history_by_power[power_name][phase].keys(): - conversation_history_str += f"\n {power}:\n\n" - conversation_history_str += self.history_by_power[power_name][ - phase - ][power] - - return conversation_history_str diff --git a/ai_diplomacy/game_history.py b/ai_diplomacy/game_history.py new file mode 100644 index 0000000..db740f5 --- /dev/null +++ b/ai_diplomacy/game_history.py @@ -0,0 +1,157 @@ +from dotenv import load_dotenv +import logging +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Dict, List, Optional + +logger = logging.getLogger("utils") +logger.setLevel(logging.INFO) +logging.basicConfig(level=logging.INFO) +load_dotenv() + + +@dataclass +class Message: + sender: str + recipient: str + content: str + + +@dataclass +class Phase: + name: str # e.g. "SPRING 1901" + messages: List[Message] = field(default_factory=list) + orders_by_power: Dict[str, List[str]] = field( + default_factory=lambda: defaultdict(list) + ) + results_by_power: Dict[str, List[List[str]]] = field( + default_factory=lambda: defaultdict(list) + ) + + def add_message(self, sender: str, recipient: str, content: str): + self.messages.append( + Message(sender=sender, recipient=recipient, content=content) + ) + + def add_orders(self, power: str, orders: List[str], results: List[List[str]]): + self.orders_by_power[power].extend(orders) + # Make sure results has the same length as orders, if not, pad with empty lists + if len(results) < len(orders): + results.extend([[] for _ in range(len(orders) - len(results))]) + self.results_by_power[power].extend(results) + + def get_global_messages(self) -> str: + result = "" + for msg in self.messages: + if msg.recipient == "GLOBAL": + result += f" {msg.sender}: {msg.content}\n" + return result + + def get_private_messages(self, power: str) -> Dict[str, str]: + conversations = defaultdict(str) + for msg in self.messages: + if msg.sender == power and msg.recipient != "GLOBAL": + conversations[msg.recipient] += f" {power}: {msg.content}\n" + elif msg.recipient == power: + conversations[msg.sender] += f" {msg.sender}: {msg.content}\n" + return conversations + + def get_all_orders_formatted(self) -> str: + if not self.orders_by_power: + return "" + + result = f"\nOrders for {self.name}:\n" + for power, orders in self.orders_by_power.items(): + result += f"{power}:\n" + results = self.results_by_power.get(power, []) + for i, order in enumerate(orders): + if i < len(results) and results[i]: + # Join multiple results with commas + result_str = f" ({', '.join(results[i])})" + else: + result_str = " (successful)" + result += f" {order}{result_str}\n" + result += "\n" + return result + + +@dataclass +class GameHistory: + phases: List[Phase] = field(default_factory=list) + + def add_phase(self, phase_name: str) -> Phase: + # Check if phase already exists + for phase in self.phases: + if phase.name == phase_name: + return phase + + # Create new phase + new_phase = Phase(name=phase_name) + self.phases.append(new_phase) + return new_phase + + def add_message(self, phase_name: str, sender: str, recipient: str, content: str): + phase = self.add_phase(phase_name) + phase.add_message(sender, recipient, content) + + def add_orders( + self, phase_name: str, power: str, orders: List[str], results: List[List[str]] + ): + phase = self.add_phase(phase_name) + phase.add_orders(power, orders, results) + + def get_game_history(self, power_name: str, num_prev_phases: int = 5) -> str: + if not self.phases: + return "" + + phases_to_report = self.phases[-num_prev_phases:] + game_history_str = "" + + # Add GLOBAL section + has_global = any(phase.get_global_messages() for phase in phases_to_report) + if has_global: + game_history_str += "GLOBAL:\n" + for phase in phases_to_report: + global_msgs = phase.get_global_messages() + if global_msgs: + game_history_str += f"\n{phase.name}:\n\n" + game_history_str += global_msgs + game_history_str += "\n" + + # Add PRIVATE section + has_private = any( + phase.get_private_messages(power_name) for phase in phases_to_report + ) + if has_private: + game_history_str += "PRIVATE:\n" + for phase in phases_to_report: + private_msgs = phase.get_private_messages(power_name) + if private_msgs: + game_history_str += f"\n{phase.name}:\n" + for other_power, messages in private_msgs.items(): + game_history_str += f"\n {other_power}:\n\n" + game_history_str += messages + + # Add ORDERS section - including all powers' orders with results + if any(phase.orders_by_power for phase in phases_to_report): + game_history_str += "\nORDERS:\n" + for phase in phases_to_report: + if phase.orders_by_power: + game_history_str += f"\n{phase.name}:\n" + for power, orders in phase.orders_by_power.items(): + game_history_str += f"{power}:\n" + results = phase.results_by_power.get(power, []) + for i, order in enumerate(orders): + if ( + i < len(results) + and results[i] + and not all(r == "" for r in results[i]) + ): + # Join multiple results with commas + result_str = f" ({', '.join(results[i])})" + else: + result_str = " (successful)" + game_history_str += f" {order}{result_str}\n" + game_history_str += "\n" + + return game_history_str diff --git a/ai_diplomacy/negotiations.py b/ai_diplomacy/negotiations.py index d5f680d..1d928ba 100644 --- a/ai_diplomacy/negotiations.py +++ b/ai_diplomacy/negotiations.py @@ -13,7 +13,7 @@ logging.basicConfig(level=logging.INFO) load_dotenv() -def conduct_negotiations(game, conversation_history, model_error_stats, max_rounds=3): +def conduct_negotiations(game, game_history, model_error_stats, max_rounds=3): """ Conducts a round-robin conversation among all non-eliminated powers. Each power can send up to 'max_rounds' messages, choosing between private @@ -23,6 +23,7 @@ def conduct_negotiations(game, conversation_history, model_error_stats, max_roun # Conversation messages are kept in a local list ONLY to build conversation_so_far text. conversation_messages = [] + return conversation_messages active_powers = [ p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated() @@ -46,9 +47,8 @@ def conduct_negotiations(game, conversation_history, model_error_stats, max_roun board_state=board_state, power_name=power_name, possible_orders=possible_orders, - conversation_history=conversation_history, + game_history=game_history, game_phase=game.current_short_phase, - phase_summaries=game.phase_summaries, active_powers=active_powers, ) @@ -62,7 +62,7 @@ def conduct_negotiations(game, conversation_history, model_error_stats, max_roun message=message["content"], ) game.add_message(diplo_message) - conversation_history.add_message( + game_history.add_message( game.current_short_phase, power_name, message ) conversation_messages.append(message) diff --git a/ai_diplomacy/prompts/context_prompt.txt b/ai_diplomacy/prompts/context_prompt.txt index 70cab54..1b48b36 100644 --- a/ai_diplomacy/prompts/context_prompt.txt +++ b/ai_diplomacy/prompts/context_prompt.txt @@ -20,13 +20,9 @@ Possible coasts at each location: All supply centers on the map: {game_map_scs} -**HISTORICAL SUMMARIES** +**GAME HISTORY** -{historical_summaries} - -**CONVERSATION HISTORY** - -{conversation_history} +{game_history} **CURRENT CONTEXT** diff --git a/ai_diplomacy/prompts/system_prompt.txt b/ai_diplomacy/prompts/system_prompt.txt index c56b0ec..755b6d1 100644 --- a/ai_diplomacy/prompts/system_prompt.txt +++ b/ai_diplomacy/prompts/system_prompt.txt @@ -1,30 +1,12 @@ -You are playing a game of Diplomacy over text. The map is the standard Diplomacy map. +You are playing a game of Diplomacy over text. The map is the standard Diplomacy map. Your goal is to win the game by capturing supply centers, growing your army, and taking over the map. Be aggressive. You will be given: • Which power you are controlling. • The current phase (e.g. S1901M). • Details about the map. • Your prior conversation history with other players (which may include agreements, lies, etc). -• Your units and the possible orders for each. Always refer to these possible_orders. -• Summaries of past phases (including bounces, voids, and the reasons). +• Your units and the possible orders you may make. Always refer to these possible_orders. +• Results of past phases (including bounces and voids) • A list of enemy units and centers. -• A “convoy_paths_possible” listing, if relevant, describing possible convoy routes (e.g. [("A NAP", {F ION, F TYS}, "TUN"), ...]). -**Your tasks**: -1) Reflect on your strategic goals and the current board situation. -2) **Strategize** to increase your supply centers, defend your existing centers, and expand influence. -3) **Coordinate** your orders so they do not produce internal contradictions. For example, do not support a move that you or your ally are not making. -4) Check if any of your previous moves were blocked or voided. Learn from that: - - If you bounced due to equal force, consider using support or picking a different target. - - If you had an invalid adjacency or a mismatch in support, fix it this turn. -5) Evaluate if you can use a convoy; consult “convoy_paths_possible” to see if a valid route exists. -6) Depending on what phase of the game we are in, either send messages to other players or propose a set of final orders. Always refer to the possible_orders list before committing to an order. -7) **Anticipate** that other powers may lie or might try to bounce you. If you suspect a bounce, consider using support or a safer move. -8) **Return a valid, consistent set of final orders or messages** in the required JSON format. - -Remember that while your private chain-of-thought can consider your in-depth reasoning about possible outcomes, **only** the “PARSABLE OUTPUT” (your final orders or messages) will be used by the game engine. - -- If you use a support order, ensure you (or an ally) actually have a matching move that it supports. Reference the map if necessary to make sure that the move is possible. -- If you do a convoy, ensure the fleets and adjacency match the “convoy_paths_possible” data. -- Attempt to avoid guaranteed bounces unless you see a diplomatic reason to do so. -- If you suspect an enemy might also move to your target, consider using support or picking a safer approach. \ No newline at end of file +Remember that while your private chain-of-thought can consider your in-depth reasoning about possible outcomes, **only** the “PARSABLE OUTPUT” (your final orders or messages) will be used by the game engine. \ No newline at end of file diff --git a/ai_diplomacy/utils.py b/ai_diplomacy/utils.py index 85c8a2e..88c044f 100644 --- a/ai_diplomacy/utils.py +++ b/ai_diplomacy/utils.py @@ -45,8 +45,7 @@ def get_valid_orders( board_state, power_name, possible_orders, - conversation_history, - phase_summaries, + game_history, model_error_stats, ): """ @@ -61,8 +60,7 @@ def get_valid_orders( board_state=board_state, power_name=power_name, possible_orders=possible_orders, - conversation_text=conversation_history, - phase_summaries=phase_summaries, + conversation_text=game_history, model_error_stats=model_error_stats, ) diff --git a/lm_game.py b/lm_game.py index c3a9c8e..a9cbb2a 100644 --- a/lm_game.py +++ b/lm_game.py @@ -20,7 +20,7 @@ from ai_diplomacy.utils import ( assign_models_to_powers, ) from ai_diplomacy.negotiations import conduct_negotiations -from ai_diplomacy.conversation_history import ConversationHistory +from ai_diplomacy.game_history import GameHistory dotenv.load_dotenv() @@ -90,7 +90,7 @@ def main(): # Create a fresh Diplomacy game game = Game() - conversation_history = ConversationHistory() + game_history = GameHistory() # Ensure game has phase_summaries attribute if not hasattr(game, "phase_summaries"): @@ -150,7 +150,7 @@ def main(): if game.current_short_phase.endswith("M"): logger.info("Starting negotiation phase block...") conversation_messages = conduct_negotiations( - game, conversation_history, model_error_stats, max_rounds=10 + game, game_history, model_error_stats, max_rounds=10 ) else: conversation_messages = [] @@ -180,8 +180,7 @@ def main(): board_state, power_name, possible_orders, - conversation_history, - game.phase_summaries, + game_history, model_error_stats, ) futures[future] = power_name @@ -209,6 +208,26 @@ def main(): sys, usr, summary_model ) ) + # Add orders to game history + for power_name in game.order_history[current_phase]: + orders = game.order_history[current_phase][power_name] + results = [] + for order in orders: + # Example move: "A PAR H" -> unit="A PAR", order_part="H" + tokens = order.split(" ", 2) + if len(tokens) < 3: + continue + unit = " ".join(tokens[:2]) # e.g. "A PAR" + order_part = tokens[2] # e.g. "H" or "S A MAR" + results.append( + [str(x) for x in game.result_history[current_phase][unit]] + ) + game_history.add_orders( + current_phase, + power_name, + game.order_history[current_phase][power_name], + results, + ) logger.info("Phase complete.\n") # Retrieve and log the summary of the phase