From 8f61ba06b360b54b1a80139d38f86bd4caff5990 Mon Sep 17 00:00:00 2001 From: AlxAI Date: Thu, 20 Feb 2025 18:22:53 -0800 Subject: [PATCH] fixed system prompt for summary , made improvements and debugging for summaries too. Much can be optimized still --- ai_diplomacy/clients.py | 50 +++++++------ .../prompts/austria_system_prompt.txt | 22 ++++++ diplomacy/engine/game.py | 60 +++++++++++---- lm_game.py | 73 ++++++++++++++----- 4 files changed, 150 insertions(+), 55 deletions(-) diff --git a/ai_diplomacy/clients.py b/ai_diplomacy/clients.py index 0547f7d..e73c6bc 100644 --- a/ai_diplomacy/clients.py +++ b/ai_diplomacy/clients.py @@ -40,18 +40,24 @@ class BaseModelClient: - get_conversation_reply(power_name, conversation_so_far, game_phase) -> str """ - def __init__(self, model_name: str, power_name: Optional[str] = None): + def __init__(self, model_name: str, power_name: Optional[str] = None, emptysystem: bool = False): self.model_name = model_name self.power_name = power_name - # Load a power-specific system prompt if present, else default - if self.power_name: - try: - self.system_prompt = load_prompt(f"{self.power_name.lower()}_system_prompt.txt") - except FileNotFoundError: - logger.warning(f"No specific system prompt found for {self.power_name}; using default.") + self.emptysystem = emptysystem + + # Conditionally load system prompt + if not self.emptysystem: + if self.power_name: + try: + self.system_prompt = load_prompt(f"{self.power_name.lower()}_system_prompt.txt") + except FileNotFoundError: + logger.warning(f"No specific system prompt found for {self.power_name}; using default.") + self.system_prompt = load_prompt("system_prompt.txt") + else: self.system_prompt = load_prompt("system_prompt.txt") else: - self.system_prompt = load_prompt("system_prompt.txt") + # If emptysystem is True, skip loading any system prompt + self.system_prompt = "" def generate_response(self, prompt: str) -> str: """ @@ -486,8 +492,8 @@ class OpenAIClient(BaseModelClient): For 
'o3-mini', 'gpt-4o', or other OpenAI model calls. """ - def __init__(self, model_name: str, power_name: Optional[str] = None): - super().__init__(model_name, power_name) + def __init__(self, model_name: str, power_name: Optional[str] = None, emptysystem: bool = False): + super().__init__(model_name, power_name, emptysystem) self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) def generate_response(self, prompt: str) -> str: @@ -523,8 +529,8 @@ class ClaudeClient(BaseModelClient): For 'claude-3-5-sonnet-20241022', 'claude-3-5-haiku-20241022', etc. """ - def __init__(self, model_name: str, power_name: Optional[str] = None): - super().__init__(model_name, power_name) + def __init__(self, model_name: str, power_name: Optional[str] = None, emptysystem: bool = False): + super().__init__(model_name, power_name, emptysystem) self.client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")) def generate_response(self, prompt: str) -> str: @@ -559,8 +565,8 @@ class GeminiClient(BaseModelClient): For 'gemini-1.5-flash' or other Google Generative AI models. 
""" - def __init__(self, model_name: str, power_name: Optional[str] = None): - super().__init__(model_name, power_name) + def __init__(self, model_name: str, power_name: Optional[str] = None, emptysystem: bool = False): + super().__init__(model_name, power_name, emptysystem) self.client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY")) def generate_response(self, prompt: str) -> str: @@ -587,8 +593,8 @@ class DeepSeekClient(BaseModelClient): For DeepSeek R1 'deepseek-reasoner' """ - def __init__(self, model_name: str, power_name: Optional[str] = None): - super().__init__(model_name, power_name) + def __init__(self, model_name: str, power_name: Optional[str] = None, emptysystem: bool = False): + super().__init__(model_name, power_name, emptysystem) self.api_key = os.environ.get("DEEPSEEK_API_KEY") self.client = DeepSeekOpenAI( api_key=self.api_key, base_url="https://api.deepseek.com/" @@ -651,22 +657,22 @@ class DeepSeekClient(BaseModelClient): ############################################################################## -def load_model_client(model_id: str, power_name: Optional[str] = None) -> BaseModelClient: +def load_model_client(model_id: str, power_name: Optional[str] = None, emptysystem: bool = False) -> BaseModelClient: """ Returns the appropriate LLM client for a given model_id string, optionally keyed by power_name. 
Example usage: - client = load_model_client("claude-3-5-sonnet-20241022", power_name="FRANCE") + client = load_model_client("claude-3-5-sonnet-20241022", power_name="FRANCE", emptysystem=True) """ lower_id = model_id.lower() if "claude" in lower_id: - return ClaudeClient(model_id, power_name) + return ClaudeClient(model_id, power_name, emptysystem=emptysystem) elif "gemini" in lower_id: - return GeminiClient(model_id, power_name) + return GeminiClient(model_id, power_name, emptysystem=emptysystem) elif "deepseek" in lower_id: - return DeepSeekClient(model_id, power_name) + return DeepSeekClient(model_id, power_name, emptysystem=emptysystem) else: # Default to OpenAI - return OpenAIClient(model_id, power_name) + return OpenAIClient(model_id, power_name, emptysystem=emptysystem) ############################################################################## diff --git a/ai_diplomacy/prompts/austria_system_prompt.txt b/ai_diplomacy/prompts/austria_system_prompt.txt index 8a6b4aa..b54ef83 100644 --- a/ai_diplomacy/prompts/austria_system_prompt.txt +++ b/ai_diplomacy/prompts/austria_system_prompt.txt @@ -1,5 +1,27 @@ You are playing a game of Diplomacy over text. The map is the standard Diplomacy map. Your goal is to win the game by capturing supply centers, growing your army, and taking over the map. Be aggressive. +Dear Austria, +They say you're surrounded - but that means you can strike in any direction. History shows the strongest Austrian players turn early vulnerability into mid-game dominance through decisive action, not just survival. +Key insights: + +Prevent Russia-Turkey alliance above all else +Italy must be friend or dead quickly (95% of A/I wars kill both) +Serbia is crucial 1901 - secure it +Galicia bounce often vital Spring 1901 +Central position enables striking anywhere once secured + +Critical mindset: You're not playing to survive - you're playing to explode out from the center. 
Yes, early diplomacy keeps you alive, but it should serve your offensive goals, not replace them. +Paths to victory often require: + +Securing strong early alliance (usually Italy or Russia) +Eliminating one neighbor completely by 1904 +Leveraging central position for unexpected strikes +Breaking stalemate line via Munich/Berlin + +Don't fall into defensive play just because everyone expects it. Stats show Austrian solos often come from players who turn the early "defensive" moves into aggressive positioning by year 3. +Time works against you - the longer you wait, the more likely others unite. Make your decisive moves by mid-game, usually years 3-4. Better to strike imperfectly than wait for perfect alignment. +The throne of Europe awaits. Show them that the "weakest" starting position was merely gathering strength to strike. + You will be given: • Which power you are controlling. • The current phase (e.g. S1901M). diff --git a/diplomacy/engine/game.py b/diplomacy/engine/game.py index e030748..7ec3302 100644 --- a/diplomacy/engine/game.py +++ b/diplomacy/engine/game.py @@ -45,6 +45,11 @@ from diplomacy.utils.game_phase_data import GamePhaseData, MESSAGES_TYPE UNDETERMINED, POWER, UNIT, LOCATION, COAST, ORDER, MOVE_SEP, OTHER = 0, 1, 2, 3, 4, 5, 6, 7 LOGGER = logging.getLogger(__name__) +# set logging level to INFO +logging.basicConfig(level=logging.INFO) +# set logging level to DEBUG +#logging.basicConfig(level=logging.DEBUG) + class Game(Jsonable): """ Game class. 
@@ -1468,6 +1473,9 @@ class Game(Jsonable): self.message_history.put(previous_phase, previous_messages) self.state_history.put(previous_phase, previous_state) + # Now build a key for the *current* (post-process) phase + current_phase_key = self._phase_wrapper_type(self.current_short_phase) + # Generate a text summary (if a callback is provided) phase_summary_text = self._generate_phase_summary( previous_phase, @@ -4575,30 +4583,54 @@ class Game(Jsonable): except (IndexError, KeyError): return f"[_generate_phase_summary] No GamePhaseData found for {phase_key}" - # Log the current phase key and results for debugging + # Log the current phase key, results, and possibly the orders for debugging logging.debug( - "DEBUG _generate_phase_summary: phase_key=%s, results=%s", - phase_key, current_phase_data.results + "DEBUG _generate_phase_summary: current phase_key=%s, results=%s, orders=%s", + phase_key, + current_phase_data.results, + current_phase_data.orders ) - # 2) Attempt to retrieve the PREVIOUS phase data to highlight differences - # We'll do this by checking the index of `phase_key` in `self.state_history`. - # If there's a previous index, we'll fetch that phase_data for comparison. - prev_phase_data = None + # Retrieve the list of all recorded phase keys all_phases = list(self.state_history.keys()) + logging.debug("DEBUG _generate_phase_summary: all_phases=%s", all_phases) + + prev_phase_data = None if str(phase_key) in all_phases: idx = all_phases.index(str(phase_key)) + logging.debug("DEBUG _generate_phase_summary: current phase index=%d", idx) + + # Here we log the logic behind picking the previous phase if idx > 0: - prev_phase_key = all_phases[idx - 1] + prev_phase_key = all_phases[idx - 1] + logging.debug( + "DEBUG _generate_phase_summary: Using prev_phase_key=%s (idx-1). 
If skipping a sub-phase is undesired, consider (idx-1).", + prev_phase_key + ) try: prev_phase_data = self.get_phase_from_history(prev_phase_key) - except: - pass + except Exception as e: + logging.debug("DEBUG _generate_phase_summary: Could not get prev_phase_data for key=%s, error=%s", prev_phase_key, e) + else: + logging.debug("DEBUG _generate_phase_summary: Not enough phases to set prev_phase_key.") + else: + logging.debug("DEBUG _generate_phase_summary: phase_key=%s not in all_phases!", phase_key) - # 3) Gather the big data from current_phase_data - # (We assume you have stored them in current_phase_data.state the usual way.) + # ... [No change in the rest of your existing logic, except we might add extra logs below] ... + + # (After retrieving prev_phase_data, we log a quick summary:) + if prev_phase_data: + logging.debug( + "DEBUG _generate_phase_summary: Found prev_phase_data for key=%s, results=%s, orders=%s", + prev_phase_key, + prev_phase_data.results, + prev_phase_data.orders + ) + + # The rest of the function remains the same, but you can keep adding targeted logs as needed: cur_state = current_phase_data.state - # Typically these keys exist if your get_state() populates them: + logging.debug("DEBUG _generate_phase_summary: cur_state keys=%s", list(cur_state.keys())) + cur_units = cur_state.get('units', {}) cur_centers = cur_state.get('centers', {}) cur_retreats = cur_state.get('retreats', {}) @@ -4703,7 +4735,7 @@ class Game(Jsonable): f"RESULTS:\n{results_block}\n\n" f"CURRENT BOARD STATE:\n{current_state_block}\n\n" f"CHANGES FROM PREVIOUS PHASE:\n{differences_block}\n\n" - "Below is the final board state after the latest phase, along with the moves each power submitted and the engine’s adjudication results. Please create a summary in JSON, explaining:" + "Below is the final board state after the latest phase, along with the moves each power submitted and the engine's adjudication results. 
Please create a summary in JSON, explaining:" "- Each successful move," "- Each bounce or voided order, with reasons (e.g. equal force, no valid route, contradictory support)," "- Key changes in supply centers," diff --git a/lm_game.py b/lm_game.py index 5cd8212..13e5cb0 100644 --- a/lm_game.py +++ b/lm_game.py @@ -34,7 +34,7 @@ logging.basicConfig( def my_summary_callback(system_prompt, user_prompt, model_name): # Route to the desired model specified by the command-line argument - client = load_model_client(model_name) + client = load_model_client(model_name, emptysystem=True) combined_prompt = f"{system_prompt}\n\n{user_prompt}" # Pseudo-code for generating a response: return client.generate_response(combined_prompt) @@ -47,7 +47,7 @@ def parse_arguments(): parser.add_argument( "--max_year", type=int, - default=1910, + default=1905, help="Maximum year to simulate. The game will stop once this year is reached.", ) parser.add_argument( @@ -59,7 +59,7 @@ def parse_arguments(): parser.add_argument( "--num_negotiation_rounds", type=int, - default=5, + default=1, help="Number of negotiation rounds per phase.", ) parser.add_argument( @@ -80,6 +80,42 @@ def parse_arguments(): return parser.parse_args() +def save_game_state(game, result_folder, game_file_path, model_error_stats, args, is_final=False): + """ + Save the current game state and related information + + Args: + game: The diplomacy game instance + result_folder: Path to the results folder + game_file_path: Base path for the game file + model_error_stats: Dictionary containing model error statistics + args: Command line arguments + is_final: Boolean indicating if this is the final save + """ + # Generate unique filename for periodic saves + timestamp = time.strftime("%Y%m%d_%H%M%S") + if not is_final: + output_path = f"{game_file_path}_checkpoint_{timestamp}.json" + else: + output_path = game_file_path + # If final file exists, append timestamp + if os.path.exists(output_path): + logger.info("Game file already 
exists, saving with unique filename.") + output_path = f"{output_path}_{timestamp}.json" + + # Save game state + to_saved_game_format(game, output_path=output_path) + + # Save overview data + overview_file_path = f"{result_folder}/overview.jsonl" + with open(overview_file_path, "w") as overview_file: + overview_file.write(json.dumps(model_error_stats) + "\n") + overview_file.write(json.dumps(game.power_model_map) + "\n") + overview_file.write(json.dumps(vars(args)) + "\n") + + logger.info(f"Saved game checkpoint to: {output_path}") + + def main(): args = parse_arguments() max_year = args.max_year @@ -135,6 +171,8 @@ def main(): else: game.power_model_map = assign_models_to_powers(randomize=True) + round_counter = 0 # Track number of rounds + while not game.is_game_done: phase_start = time.time() current_phase = game.get_current_phase() @@ -143,7 +181,7 @@ def main(): ) # DEBUG: Print the short phase to confirm - logger.info(f"DEBUG: current_short_phase is '{game.current_short_phase}'") + logger.info(f"INFO: The current short phase is '{game.current_short_phase}'") # Prevent unbounded simulation based on year year_str = current_phase[1:5] @@ -253,6 +291,14 @@ def main(): with open(manifesto_path, "a") as f: f.write(f"=== {phase_data.name} ===\n{summary_text}\n\n") + # Increment round counter after processing each phase + round_counter += 1 + + # Save every 5 rounds + if round_counter % 5 == 0: + logger.info(f"Saving checkpoint after round {round_counter}...") + save_game_state(game, result_folder, game_file_path, model_error_stats, args, is_final=False) + # Check if we've exceeded the max year year_str = current_phase[1:5] year_int = int(year_str) @@ -262,21 +308,10 @@ def main(): # Save final result duration = time.time() - start_whole - logger.info(f"Game ended after {duration:.2f}s. 
Saving to final JSON...") - - output_path = game_file_path - # If the file already exists, append a timestamp to the filename - if os.path.exists(output_path): - logger.info("Game file already exists, saving with unique filename.") - output_path = f"{output_path}_{time.strftime('%Y%m%d_%H%M%S')}.json" - to_saved_game_format(game, output_path=output_path) - - # Dump error stats and power model mapping to the overview file - with open(overview_file_path, "w") as overview_file: - overview_file.write(json.dumps(model_error_stats) + "\n") - overview_file.write(json.dumps(game.power_model_map) + "\n") - overview_file.write(json.dumps(vars(args)) + "\n") - + logger.info(f"Game ended after {duration:.2f}s. Saving final state...") + + save_game_state(game, result_folder, game_file_path, model_error_stats, args, is_final=True) + logger.info(f"Saved game data, manifesto, and error stats in: {result_folder}") logger.info("Done.")