diff --git a/lm_service_versus.py b/ai_diplomacy/clients.py similarity index 72% rename from lm_service_versus.py rename to ai_diplomacy/clients.py index a58e2dd..4a5b723 100644 --- a/lm_service_versus.py +++ b/ai_diplomacy/clients.py @@ -1,5 +1,6 @@ import os import json +from json import JSONDecodeError import re import logging import ast @@ -7,24 +8,20 @@ import ast from typing import List, Dict, Optional from dotenv import load_dotenv -# Anthropics import anthropic -# Google Generative AI -# Set gemini to more verbose -os.environ['GRPC_PYTHON_LOG_LEVEL'] = '10' +os.environ["GRPC_PYTHON_LOG_LEVEL"] = "10" import google.generativeai as genai # Import after setting log level - -# DeepSeek from openai import OpenAI as DeepSeekOpenAI # set logger back to just info -logger = logging.getLogger('lm_service_versus') +logger = logging.getLogger("client") logger.setLevel(logging.INFO) logging.basicConfig(level=logging.INFO) load_dotenv() + ############################################################################## # 1) Base Interface ############################################################################## @@ -41,6 +38,7 @@ class BaseModelClient: self.model_name = model_name self.system_prompt_response = load_prompt("system_prompt_response.txt") self.system_prompt_conversation = load_prompt("system_prompt_conversation.txt") + def generate_response(self, prompt: str) -> str: """ Returns a raw string from the LLM. @@ -48,12 +46,11 @@ class BaseModelClient: """ raise NotImplementedError("Subclasses must implement generate_response().") - def build_prompt( - self, - board_state, - power_name: str, - possible_orders: Dict[str, List[str]], + self, + board_state, + power_name: str, + possible_orders: Dict[str, List[str]], conversation_text: str, phase_summaries: Optional[Dict[str, str]] = None, ) -> str: @@ -75,7 +72,6 @@ class BaseModelClient: enemy_units[power] = info enemy_centers[power] = board_state["centers"].get(power, []) - summary = ( f"Power: {power_name}\n" f"Current Phase: {year_phase}\n" @@ -101,43 +97,52 @@ class BaseModelClient: historical_summaries = "\n(No historical summaries provided)\n" prompt = ( - "Relevant Conversation:\n" + conversation_text + "\n\n" - + "Historical Summaries:\n" + historical_summaries + "\n\n" - + summary + few_shot_example + "\n" + "Relevant Conversation:\n" + + conversation_text + + "\n\n" + + "Historical Summaries:\n" + + historical_summaries + + "\n\n" + + summary + + few_shot_example + + "\n" + instructions ) return prompt def get_orders( - self, - board_state, - power_name: str, - possible_orders: Dict[str, List[str]], + self, + board_state, + power_name: str, + possible_orders: Dict[str, List[str]], conversation_text: str, phase_summaries: Optional[Dict[str, str]] = None, - model_error_stats=None # New optional param + model_error_stats=None, # New optional param ) -> List[str]: """ 1) Builds the prompt with conversation context if available 2) Calls LLM 3) Parses JSON block """ - prompt = self.build_prompt(board_state, power_name, possible_orders, conversation_text, phase_summaries) + prompt = self.build_prompt( + board_state, power_name, possible_orders, conversation_text, phase_summaries + ) raw_response = "" try: raw_response = self.generate_response(prompt) logger.info(f"[{self.model_name}] prompted for {power_name}:\n{prompt}") - logger.info(f"[{self.model_name}] Raw LLM response for {power_name}:\n{raw_response}") + logger.info( + f"[{self.model_name}] Raw LLM response for {power_name}:\n{raw_response}" + ) # Attempt to parse the final "orders" from the LLM move_list = self._extract_moves(raw_response, power_name) - print(move_list) if not move_list: - import pdb; pdb.set_trace() - - logger.warning(f"[{self.model_name}] Could not extract moves for {power_name}. Using fallback.") + logger.warning( + f"[{self.model_name}] Could not extract moves for {power_name}. Using fallback." + ) if model_error_stats is not None: model_error_stats[self.model_name]["order_decoding_errors"] += 1 return self.fallback_orders(possible_orders) @@ -152,50 +157,58 @@ class BaseModelClient: def _extract_moves(self, raw_response: str, power_name: str) -> Optional[List[str]]: """ Attempt multiple parse strategies to find JSON array of moves. - + 1. Regex for PARSABLE OUTPUT lines. 2. If that fails, also look for fenced code blocks with { ... }. 3. Attempt bracket-based fallback if needed. - + Returns a list of move strings or None if everything fails. """ # 1) Regex for "PARSABLE OUTPUT:{...}" pattern = r"PARSABLE OUTPUT:\s*(\{[\s\S]*\})" matches = re.search(pattern, raw_response, re.DOTALL) - + if not matches: # Some LLMs might not put the colon or might have triple backtick fences. - logger.debug(f"[{self.model_name}] Regex parse #1 failed for {power_name}. Trying alternative patterns.") - + logger.debug( + f"[{self.model_name}] Regex parse #1 failed for {power_name}. Trying alternative patterns." + ) + # 1b) Check for inline JSON after "PARSABLE OUTPUT" pattern_alt = r"PARSABLE OUTPUT\s*\{(.*?)\}\s*$" matches = re.search(pattern_alt, raw_response, re.DOTALL) - - if not matches: - logger.debug(f"[{self.model_name}] Regex parse #2 failed for {power_name}. Trying triple-backtick code fences.") + + if not matches: + logger.debug( + f"[{self.model_name}] Regex parse #2 failed for {power_name}. Trying triple-backtick code fences." + ) # 2) If still no match, check for triple-backtick code fences containing JSON if not matches: code_fence_pattern = r"```json\s*(\{.*?\})\s*```" matches = re.search(code_fence_pattern, raw_response, re.DOTALL) if matches: - logger.debug(f"[{self.model_name}] Found triple-backtick JSON block for {power_name}.") - + logger.debug( + f"[{self.model_name}] Found triple-backtick JSON block for {power_name}." + ) + # 3) Attempt to parse JSON if we found anything json_text = None if matches: # Add braces back around the captured group - if matches.group(1).strip().startswith(r"{{"): + if matches.group(1).strip().startswith(r"{{"): json_text = matches.group(1).strip()[1:-1] elif matches.group(1).strip().startswith(r"{"): json_text = matches.group(1).strip() - else: - json_text = "{%s}" % matches.group(1).strip - + else: + json_text = "{%s}" % matches.group(1).strip + json_text = json_text.strip() if not json_text: - logger.debug(f"[{self.model_name}] No JSON text found in LLM response for {power_name}.") + logger.debug( + f"[{self.model_name}] No JSON text found in LLM response for {power_name}." + ) return None # 3a) Try JSON loading @@ -203,7 +216,9 @@ class BaseModelClient: data = json.loads(json_text) return data.get("orders", None) except json.JSONDecodeError as e: - logger.warning(f"[{self.model_name}] JSON decode failed for {power_name}: {e}. Trying bracket fallback.") + logger.warning( + f"[{self.model_name}] JSON decode failed for {power_name}: {e}. Trying bracket fallback." + ) # 3b) Attempt bracket fallback: we look for the substring after "orders" # E.g. "orders: ['A BUD H']" and parse it. This is risky but can help with minor JSON format errors. @@ -217,12 +232,16 @@ class BaseModelClient: if isinstance(moves, list): return moves except Exception as e2: - logger.warning(f"[{self.model_name}] Bracket fallback parse also failed for {power_name}: {e2}") + logger.warning( + f"[{self.model_name}] Bracket fallback parse also failed for {power_name}: {e2}" + ) # If all attempts failed return None - def _validate_orders(self, moves: List[str], possible_orders: Dict[str, List[str]]) -> List[str]: + def _validate_orders( + self, moves: List[str], possible_orders: Dict[str, List[str]] + ) -> List[str]: """ Filter out invalid moves, fill missing with HOLD, else fallback. """ @@ -249,7 +268,9 @@ class BaseModelClient: for loc, orders_list in possible_orders.items(): if loc not in used_locs and orders_list: hold_candidates = [o for o in orders_list if o.endswith("H")] - validated.append(hold_candidates[0] if hold_candidates else orders_list[0]) + validated.append( + hold_candidates[0] if hold_candidates else orders_list[0] + ) if not validated: logger.warning(f"[{self.model_name}] All moves invalid, fallback.") @@ -270,9 +291,9 @@ class BaseModelClient: return fallback def build_conversation_reply( - self, - power_name: str, - conversation_so_far: str, + self, + power_name: str, + conversation_so_far: str, game_phase: str, phase_summaries: Optional[Dict[str, str]] = None, ) -> str: @@ -283,26 +304,34 @@ class BaseModelClient: power_name=power_name, game_phase=game_phase, phase_summaries=phase_summaries, - conversation_so_far=conversation_so_far + conversation_so_far=conversation_so_far, ) - - def generate_conversation_reply(self, power_name: str, conversation_so_far: str, game_phase: str) -> str: + + def generate_conversation_reply( + self, power_name: str, conversation_so_far: str, game_phase: str + ) -> str: """ Overwritten by subclasses. """ - raise NotImplementedError("Subclasses must implement generate_conversation_reply().") + raise NotImplementedError( + "Subclasses must implement generate_conversation_reply()." + ) + ############################################################################## # 2) Concrete Implementations ############################################################################## + class OpenAIClient(BaseModelClient): """ For 'o3-mini', 'gpt-4o', or other OpenAI model calls. """ + def __init__(self, model_name: str): super().__init__(model_name) from openai import OpenAI # Import the new client + self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) def generate_response(self, prompt: str) -> str: @@ -313,35 +342,45 @@ class OpenAIClient(BaseModelClient): model=self.model_name, messages=[ {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt} + {"role": "user", "content": prompt}, ], ) if not response or not hasattr(response, "choices") or not response.choices: - logger.warning(f"[{self.model_name}] Empty or invalid result in generate_response. Returning empty.") + logger.warning( + f"[{self.model_name}] Empty or invalid result in generate_response. Returning empty." + ) return "" return response.choices[0].message.content.strip() except json.JSONDecodeError as json_err: - logger.error(f"[{self.model_name}] JSON decoding failed in generate_response: {json_err}") + logger.error( + f"[{self.model_name}] JSON decoding failed in generate_response: {json_err}" + ) return "" except Exception as e: - logger.error(f"[{self.model_name}] Unexpected error in generate_response: {e}") + logger.error( + f"[{self.model_name}] Unexpected error in generate_response: {e}" + ) return "" def get_conversation_reply( - self, - power_name: str, - conversation_so_far: str, - game_phase: str, + self, + power_name: str, + conversation_so_far: str, + game_phase: str, phase_summaries: Optional[Dict[str, str]] = None, ) -> str: """ Produces a single message with the appropriate JSON format. """ - import json from json.decoder import JSONDecodeError + # load the system prompt but formatted with the power name and game phase - system_prompt = self.system_prompt_conversation.format(power_name=power_name, game_phase=game_phase) - conversation_prompt = self.build_conversation_reply(power_name, conversation_so_far, game_phase, phase_summaries) + system_prompt = self.system_prompt_conversation.format( + power_name=power_name, game_phase=game_phase + ) + conversation_prompt = self.build_conversation_reply( + power_name, conversation_so_far, game_phase, phase_summaries + ) try: # Perform the request @@ -349,35 +388,39 @@ class OpenAIClient(BaseModelClient): model=self.model_name, messages=[ {"role": "system", "content": system_prompt}, - {"role": "user", "content": conversation_prompt} + {"role": "user", "content": conversation_prompt}, ], - max_completion_tokens=2000 + max_completion_tokens=2000, ) # If there's no valid response or choices, return empty if not response or not hasattr(response, "choices") or not response.choices: - logger.warning(f"[{self.model_name}] Empty or invalid response for {power_name}. Returning empty.") + logger.warning( + f"[{self.model_name}] Empty or invalid response for {power_name}. Returning empty." + ) return "" # Attempt to parse the content (OpenAI library usually does this, but we add a safety net) return response.choices[0].message.content.strip() except JSONDecodeError as json_err: - logger.error(f"[{self.model_name}] JSON decoding failed for {power_name}: {json_err}") + logger.error( + f"[{self.model_name}] JSON decoding failed for {power_name}: {json_err}" + ) return "" # Fallback except Exception as e: logger.error(f"[{self.model_name}] Unexpected error for {power_name}: {e}") return "" + class ClaudeClient(BaseModelClient): """ For 'claude-3-5-sonnet-20241022', 'claude-3-5-haiku-20241022', etc. """ + def __init__(self, model_name: str): super().__init__(model_name) - self.client = anthropic.Anthropic( - api_key=os.environ.get("ANTHROPIC_API_KEY") - ) + self.client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")) def generate_response(self, prompt: str) -> str: system_prompt = self.system_prompt_response @@ -387,52 +430,64 @@ class ClaudeClient(BaseModelClient): model=self.model_name, max_tokens=2000, system=system_prompt, # system is now a top-level parameter - messages=[ - {"role": "user", "content": prompt} - ] + messages=[{"role": "user", "content": prompt}], ) if not response.content: - logger.warning(f"[{self.model_name}] Empty content in Claude generate_response. Returning empty.") + logger.warning( + f"[{self.model_name}] Empty content in Claude generate_response. Returning empty." + ) return "" return response.content[0].text.strip() if response.content else "" except json.JSONDecodeError as json_err: - logger.error(f"[{self.model_name}] JSON decoding failed in generate_response: {json_err}") + logger.error( + f"[{self.model_name}] JSON decoding failed in generate_response: {json_err}" + ) return "" except Exception as e: - logger.error(f"[{self.model_name}] Unexpected error in generate_response: {e}") + logger.error( + f"[{self.model_name}] Unexpected error in generate_response: {e}" + ) return "" def get_conversation_reply( - self, - power_name: str, - conversation_so_far: str, - game_phase: str, + self, + power_name: str, + conversation_so_far: str, + game_phase: str, phase_summaries: Optional[Dict[str, str]] = None, ) -> str: system_prompt = f"You are playing as {power_name} in this Diplomacy negotiation phase {game_phase}." - user_prompt = self.build_conversation_reply(power_name, conversation_so_far, game_phase, phase_summaries) + user_prompt = self.build_conversation_reply( + power_name, conversation_so_far, game_phase, phase_summaries + ) try: response = self.client.messages.create( model=self.model_name, system=system_prompt, messages=[{"role": "user", "content": user_prompt}], - max_tokens=2000 + max_tokens=2000, ) if not response.content: - logger.warning(f"[{self.model_name}] No content in Claude conversation. Returning empty.") + logger.warning( + f"[{self.model_name}] No content in Claude conversation. Returning empty." + ) return "" return response.content[0].text.strip() except json.JSONDecodeError as json_err: - logger.error(f"[{self.model_name}] JSON decoding failed in conversation: {json_err}") + logger.error( + f"[{self.model_name}] JSON decoding failed in conversation: {json_err}" + ) return "" except Exception as e: logger.error(f"[{self.model_name}] Unexpected error in conversation: {e}") return "" + class GeminiClient(BaseModelClient): """ For 'gemini-1.5-flash' or other Google Generative AI models. """ + def __init__(self, model_name: str): super().__init__(model_name) genai.configure(api_key=os.environ.get("GEMINI_API_KEY")) @@ -444,15 +499,16 @@ class GeminiClient(BaseModelClient): def generate_response(self, prompt: str) -> str: system_prompt = self.system_prompt_response full_prompt = system_prompt + prompt - + try: model = genai.GenerativeModel( - self.model_name, - generation_config=self.generation_config + self.model_name, generation_config=self.generation_config ) response = model.generate_content(full_prompt) if not response or not response.text: - logger.warning(f"[{self.model_name}] Empty Gemini generate_response. Returning empty.") + logger.warning( + f"[{self.model_name}] Empty Gemini generate_response. Returning empty." + ) return "" return response.text.strip() except Exception as e: @@ -460,50 +516,60 @@ class GeminiClient(BaseModelClient): return "" def get_conversation_reply( - self, - power_name: str, - conversation_so_far: str, - game_phase: str, + self, + power_name: str, + conversation_so_far: str, + game_phase: str, phase_summaries: Optional[Dict[str, str]] = None, ) -> str: """ - Produce a single short conversation message from the Gemini model, + Produce a single short conversation message from the Gemini model, given existing conversation context. """ # Similar approach: create a system plus user prompt, then call model.generate_content system_prompt = f"You are playing as {power_name} in this Diplomacy negotiation phase {game_phase}.\n" - user_prompt = self.build_conversation_reply(power_name, conversation_so_far, game_phase, phase_summaries) + user_prompt = self.build_conversation_reply( + power_name, conversation_so_far, game_phase, phase_summaries + ) full_prompt = system_prompt + user_prompt try: model = genai.GenerativeModel( - self.model_name, - generation_config=self.generation_config + self.model_name, generation_config=self.generation_config ) response = model.generate_content(full_prompt) if not response or not response.text: - logger.warning(f"[{self.model_name}] Empty Gemini conversation response. Returning empty.") + logger.warning( + f"[{self.model_name}] Empty Gemini conversation response. Returning empty." + ) return "" else: - logger.info(f"[{self.model_name}] Gemini message succesfully generated.") + logger.info( + f"[{self.model_name}] Gemini message succesfully generated." + ) return response.text.strip() except json.JSONDecodeError as json_err: - logger.error(f"[{self.model_name}] JSON decode error in conversation: {json_err}") + logger.error( + f"[{self.model_name}] JSON decode error in conversation: {json_err}" + ) return "" except Exception as e: - logger.error(f"[{self.model_name}] Error in Gemini get_conversation_reply: {e}") + logger.error( + f"[{self.model_name}] Error in Gemini get_conversation_reply: {e}" + ) return "" + class DeepSeekClient(BaseModelClient): """ For DeepSeek R1 'deepseek-reasoner' """ + def __init__(self, model_name: str): super().__init__(model_name) self.api_key = os.environ.get("DEEPSEEK_API_KEY") self.client = DeepSeekOpenAI( - api_key=self.api_key, - base_url="https://api.deepseek.com/" + api_key=self.api_key, base_url="https://api.deepseek.com/" ) def generate_response(self, prompt: str) -> str: @@ -513,14 +579,16 @@ class DeepSeekClient(BaseModelClient): model=self.model_name, messages=[ {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt} + {"role": "user", "content": prompt}, ], - stream=False + stream=False, ) logger.debug(f"[{self.model_name}] Raw DeepSeek response:\n{response}") if not response or not response.choices: - logger.warning(f"[{self.model_name}] No valid response in generate_response.") + logger.warning( + f"[{self.model_name}] No valid response in generate_response." + ) return "" content = response.choices[0].message.content.strip() @@ -534,57 +602,79 @@ class DeepSeekClient(BaseModelClient): if json_response["message_type"] == "private": required_fields.append("recipient") if not all(field in json_response for field in required_fields): - logger.error(f"[{self.model_name}] Missing required fields in response: {content}") + logger.error( + f"[{self.model_name}] Missing required fields in response: {content}" + ) return "" return content - except json.JSONDecodeError: - logger.error(f"[{self.model_name}] Response is not valid JSON: {content}") + except JSONDecodeError: + logger.error( + f"[{self.model_name}] Response is not valid JSON: {content}" + ) content = content.replace("'", '"') try: json.loads(content) return content - except: + except JSONDecodeError: return "" except Exception as e: - logger.error(f"[{self.model_name}] Unexpected error in generate_response: {e}") + logger.error( + f"[{self.model_name}] Unexpected error in generate_response: {e}" + ) return "" - + def get_conversation_reply( - self, - power_name: str, - conversation_so_far: str, - game_phase: str, + self, + power_name: str, + conversation_so_far: str, + game_phase: str, phase_summaries: Optional[Dict[str, str]] = None, ) -> str: - system_prompt = self.system_prompt_conversation.format(power_name=power_name, game_phase=game_phase) - user_prompt = self.build_conversation_reply(power_name, conversation_so_far, game_phase, phase_summaries) - user_prompt += "\n\nPlease provide ONLY a single JSON object as per the examples above." + system_prompt = self.system_prompt_conversation.format( + power_name=power_name, game_phase=game_phase + ) + user_prompt = self.build_conversation_reply( + power_name, conversation_so_far, game_phase, phase_summaries + ) + user_prompt += ( + "\n\nPlease provide ONLY a single JSON object as per the examples above." + ) try: response = self.client.chat.completions.create( model=self.model_name, - messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}], - max_completion_tokens=2000 + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + max_completion_tokens=2000, + ) + logger.debug( + f"[{self.model_name}] Raw DeepSeek conversation response:\n{response}" ) - logger.debug(f"[{self.model_name}] Raw DeepSeek conversation response:\n{response}") if not response or not response.choices: - logger.warning(f"[{self.model_name}] No valid choices in conversation reply.") + logger.warning( + f"[{self.model_name}] No valid choices in conversation reply." + ) return "" return response.choices[0].message.content.strip() except json.JSONDecodeError as json_err: - logger.error(f"[{self.model_name}] JSON decode error in conversation: {json_err}") + logger.error( + f"[{self.model_name}] JSON decode error in conversation: {json_err}" + ) return "" except Exception as e: logger.error(f"[{self.model_name}] Unexpected error in conversation: {e}") - return "" + return "" ############################################################################## # 3) Factory to Load Model Client ############################################################################## + def load_model_client(model_id: str) -> BaseModelClient: """ Returns the appropriate LLM client for a given model_id string. @@ -608,13 +698,13 @@ def load_model_client(model_id: str) -> BaseModelClient: # 4) Example Usage in a Diplomacy "main" or Similar ############################################################################## + def assign_models_to_powers(): """ Example usage: define which model each power uses. Return a dict: { power_name: model_id, ... } POWERS = ['AUSTRIA', 'ENGLAND', 'FRANCE', 'GERMANY', 'ITALY', 'RUSSIA', 'TURKEY'] """ - # "RUSSIA": "deepseek-reasoner", deepseek api having issues return { "FRANCE": "gemini-2.0-flash", "GERMANY": "gemini-2.0-flash", @@ -624,7 +714,7 @@ def assign_models_to_powers(): "AUSTRIA": "gemini-2.0-flash", "TURKEY": "gemini-2.0-flash", } - + # return { # "FRANCE": "o3-mini", # "GERMANY": "claude-3-5-sonnet-20241022", @@ -635,12 +725,17 @@ def assign_models_to_powers(): # "TURKEY": "claude-3-5-haiku-20241022", # } + def example_game_loop(game): """ Pseudocode: Integrate with the Diplomacy loop. """ # Suppose we gather all active powers - active_powers = [(p_name, p_obj) for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()] + active_powers = [ + (p_name, p_obj) + for p_name, p_obj in game.powers.items() + if not p_obj.is_eliminated() + ] power_model_mapping = assign_models_to_powers() for power_name, power_obj in active_powers: @@ -658,20 +753,23 @@ def example_game_loop(game): # Then process, etc. game.process() + class LMServiceVersus: """ Optional wrapper class if you want extra control. For example, you could store or reuse clients, etc. """ + def __init__(self): self.power_model_map = assign_models_to_powers() - + def get_orders_for_power(self, game, power_name): model_id = self.power_model_map.get(power_name, "o3-mini") client = load_model_client(model_id) possible_orders = gather_possible_orders(game, power_name) board_state = game.get_state() - return client.get_orders(board_state, power_name, possible_orders) + return client.get_orders(board_state, power_name, possible_orders) + ############################################################################## # 1) Add a method to filter visible messages (near top-level or in BaseModelClient) @@ -684,14 +782,16 @@ def get_visible_messages_for_power(conversation_messages, power_name): for msg in conversation_messages: # GLOBAL might be 'ALL' or 'GLOBAL' depending on your usage if ( - msg['recipient'] == 'ALL' or msg['recipient'] == 'GLOBAL' - or msg['sender'] == power_name - or msg['recipient'] == power_name + msg["recipient"] == "ALL" + or msg["recipient"] == "GLOBAL" + or msg["sender"] == power_name + or msg["recipient"] == power_name ): visible.append(msg) - return visible # already in chronological order if appended that way + return visible # already in chronological order if appended that way + def load_prompt(filename: str) -> str: """Helper to load prompt text from file""" with open(f"./prompts/{filename}", "r") as f: - return f.read().strip() \ No newline at end of file + return f.read().strip() diff --git a/lm_game.py b/lm_game.py index 5d70742..febe44a 100644 --- a/lm_game.py +++ b/lm_game.py @@ -9,8 +9,7 @@ import json from collections import defaultdict # Suppress Gemini/PaLM gRPC warnings -os.environ['GRPC_PYTHON_LOG_LEVEL'] = '40' # ERROR level only -import google.generativeai as genai # Import after setting log level +os.environ["GRPC_PYTHON_LOG_LEVEL"] = "40" # ERROR level only from diplomacy import Game from diplomacy.utils.export import to_saved_game_format @@ -21,7 +20,7 @@ from diplomacy.engine.message import Message, GLOBAL # For concurrency: import concurrent.futures -from lm_service_versus import load_model_client, assign_models_to_powers +from ai_diplomacy.clients import load_model_client, assign_models_to_powers dotenv.load_dotenv() @@ -29,9 +28,10 @@ logger = logging.getLogger(__name__) logging.basicConfig( level=logging.DEBUG, format="%(asctime)s [%(levelname)s] %(name)s - %(message)s", - datefmt="%H:%M:%S" + datefmt="%H:%M:%S", ) + def gather_possible_orders(game, power_name): """ Returns a dictionary mapping each orderable location to the list of valid orders. @@ -44,6 +44,7 @@ def gather_possible_orders(game, power_name): result[loc] = all_possible.get(loc, []) return result + def conduct_negotiations(game, model_error_stats, max_rounds=10): """ Conducts a round-robin conversation among all non-eliminated powers. @@ -66,11 +67,15 @@ def conduct_negotiations(game, model_error_stats, max_rounds=10): visible_messages = [] for msg in conversation_messages: # Include if message is global or if power is sender/recipient - if msg['recipient'] == GLOBAL or msg['sender'] == power_name or msg['recipient'] == power_name: + if ( + msg["recipient"] == GLOBAL + or msg["sender"] == power_name + or msg["recipient"] == power_name + ): visible_messages.append( f"{msg['sender']} to {msg['recipient']}: {msg['content']}" ) - + conversation_so_far = "\n".join(visible_messages) # Add few-shot example for message format @@ -99,60 +104,69 @@ If you have your plan already figured out, you can just send a public '.' to ind power_name=power_name, conversation_so_far=conversation_so_far + "\n" + few_shot_example, game_phase=game.current_short_phase, - phase_summaries=game.phase_summaries + phase_summaries=game.phase_summaries, ) if new_message: try: # Parse the JSON response # Find the JSON block between curly braces - json_match = re.search(r'\{[^}]+\}', new_message) + json_match = re.search(r"\{[^}]+\}", new_message) if json_match: message_data = json.loads(json_match.group(0)) - + # Extract message details - message_type = message_data.get('message_type', 'global') - content = message_data.get('content', '').strip() - recipient = message_data.get('recipient', GLOBAL) - + message_type = message_data.get("message_type", "global") + content = message_data.get("content", "").strip() + recipient = message_data.get("recipient", GLOBAL) + # Validate recipient if private message - if message_type == 'private' and recipient not in active_powers: - logger.warning(f"Invalid recipient {recipient} for private message, defaulting to GLOBAL") + if message_type == "private" and recipient not in active_powers: + logger.warning( + f"Invalid recipient {recipient} for private message, defaulting to GLOBAL" + ) recipient = GLOBAL - + # For private messages, ensure recipient is specified - if message_type == 'private' and recipient == GLOBAL: - logger.warning("Private message without recipient specified, defaulting to GLOBAL") - + if message_type == "private" and recipient == GLOBAL: + logger.warning( + "Private message without recipient specified, defaulting to GLOBAL" + ) + # Log for debugging - logger.info(f"Power {power_name} sends {message_type} message to {recipient}") - + logger.info( + f"Power {power_name} sends {message_type} message to {recipient}" + ) + # Keep local record for building future conversation context - conversation_messages.append({ - "sender": power_name, - "recipient": recipient, - "content": content - }) + conversation_messages.append( + { + "sender": power_name, + "recipient": recipient, + "content": content, + } + ) # Create an official message in the Diplomacy engine diplo_message = Message( phase=game.current_short_phase, sender=power_name, recipient=recipient, - message=content + message=content, ) game.add_message(diplo_message) - except (json.JSONDecodeError, AttributeError) as e: logger.error(f"Failed to parse message from {power_name}: {e}") # Increment conversation parse error model_id = game.power_model_map.get(power_name, "unknown") model_error_stats[model_id]["conversation_errors"] += 1 continue + logger.info("Negotiation phase complete.") return conversation_messages + def my_summary_callback(system_prompt, user_prompt): # e.g., route to your desired model: client = load_model_client("o3-mini") @@ -160,15 +174,18 @@ def my_summary_callback(system_prompt, user_prompt): # Pseudo-code for generating a response: return client.generate_response(combined_prompt) -def get_valid_orders_with_retry(game, - client, - board_state, - power_name, - possible_orders, - conversation_text_for_orders, - phase_summaries, - model_error_stats, - max_retries=3): + +def get_valid_orders_with_retry( + game, + client, + board_state, + power_name, + possible_orders, + conversation_text_for_orders, + phase_summaries, + model_error_stats, + max_retries=3, +): """ Tries up to 'max_retries' to generate and validate orders. If invalid, we append the error feedback to the conversation @@ -190,11 +207,11 @@ def get_valid_orders_with_retry(game, possible_orders=possible_orders, conversation_text=augmented_conversation_text, phase_summaries=phase_summaries, - model_error_stats=model_error_stats + model_error_stats=model_error_stats, ) - - print(f'orders: {orders}') - + + print(f"orders: {orders}") + # Validate each order invalid_info = [] for move in orders: @@ -206,13 +223,15 @@ def get_valid_orders_with_retry(game, ) continue unit = " ".join(tokens[:2]) # e.g. "A PAR" - order_part = tokens[2] # e.g. "H" or "S A MAR" + order_part = tokens[2] # e.g. "H" or "S A MAR" # Use the internal game validation method - if order_part == 'B': - validity = 1 # hack because game._valid_order doesn't support 'B' - else: - validity = game._valid_order(game.powers[power_name], unit, order_part, report=1) + if order_part == "B": + validity = 1 # hack because game._valid_order doesn't support 'B' + else: + validity = game._valid_order( + game.powers[power_name], unit, order_part, report=1 + ) if validity != 1: invalid_info.append( f"Order '{move}' returned validity={validity}. (None/-1=invalid, 0=partial, 1=valid)" @@ -224,7 +243,7 @@ def get_valid_orders_with_retry(game, else: # Build feedback for the next retry error_feedback = ( - f"Attempt {attempt+1}/{max_retries} had invalid orders:\n" + f"Attempt {attempt + 1}/{max_retries} had invalid orders:\n" + "\n".join(invalid_info) ) @@ -235,17 +254,21 @@ def get_valid_orders_with_retry(game, fallback = client.fallback_orders(possible_orders) return fallback + def main(): - logger.info("Starting a new Diplomacy game for testing with multiple LLMs, now concurrent!") + logger.info( + "Starting a new Diplomacy game for testing with multiple LLMs, now concurrent!" + ) start_whole = time.time() - from collections import defaultdict - model_error_stats = defaultdict(lambda: {"conversation_errors": 0, "order_decoding_errors": 0}) + model_error_stats = defaultdict( + lambda: {"conversation_errors": 0, "order_decoding_errors": 0} + ) # Create a fresh Diplomacy game game = Game() # Ensure game has phase_summaries = {} - if not hasattr(game, 'phase_summaries'): + if not hasattr(game, "phase_summaries"): game.phase_summaries = {} # For storing results in a unique subfolder @@ -260,12 +283,14 @@ def main(): stats_file_path = f"{result_folder}/error_stats.json" game.power_model_map = assign_models_to_powers() - max_year = 1901 + max_year = 1910 while not game.is_game_done: phase_start = time.time() current_phase = game.get_current_phase() - logger.info(f"PHASE: {current_phase} (time so far: {phase_start - start_whole:.2f}s)") + logger.info( + f"PHASE: {current_phase} (time so far: {phase_start - start_whole:.2f}s)" + ) # DEBUG: Print the short phase to confirm logger.info(f"DEBUG: current_short_phase is '{game.current_short_phase}'") @@ -280,24 +305,31 @@ def main(): # Use endswith("M") for movement phases (like F1901M, S1902M) if game.current_short_phase.endswith("M"): logger.info("Starting negotiation phase block...") - conversation_messages = conduct_negotiations(game, model_error_stats, max_rounds=10) + conversation_messages = conduct_negotiations( + game, model_error_stats, max_rounds=10 + ) else: # If we have no conversation_messages in phases that are not Movement (e.g. Retreat/Build) conversation_messages = [] - conversation_text_for_orders = "\n".join([ - f"{msg['sender']} to {msg['recipient']}: {msg['content']}" - for msg in conversation_messages - ]) + conversation_text_for_orders = "\n".join( + [ + f"{msg['sender']} to {msg['recipient']}: {msg['content']}" + for msg in conversation_messages + ] + ) # Gather orders from each power concurrently active_powers = [ - (p_name, p_obj) for p_name, p_obj in game.powers.items() + (p_name, p_obj) + for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated() ] # Then proceed with concurrent order generation - with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + with concurrent.futures.ThreadPoolExecutor( + max_workers=len(active_powers) + ) as executor: futures = {} for power_name, _ in active_powers: model_id = game.power_model_map.get(power_name, "o3-mini") @@ -319,10 +351,12 @@ def main(): conversation_text_for_orders, # existing conversation text game.phase_summaries, model_error_stats, - 3 # max_retries + 3, # max_retries ) futures[future] = power_name - logger.debug(f"Submitted get_valid_orders_with_retry task for {power_name}.") + logger.debug( + f"Submitted get_valid_orders_with_retry task for {power_name}." + ) for future in concurrent.futures.as_completed(futures): p_name = futures[future] @@ -331,7 +365,9 @@ def main(): logger.debug(f"Validated orders for {p_name}: {orders}") if orders: game.set_orders(p_name, orders) - logger.debug(f"Set orders for {p_name} in {game.current_short_phase}: {orders}") + logger.debug( + f"Set orders for {p_name} in {game.current_short_phase}: {orders}" + ) else: logger.debug(f"No valid orders returned for {p_name}.") except Exception as exc: @@ -346,7 +382,9 @@ def main(): # Print in pretty ASCII format border = "=" * 80 - logger.info(f"{border}\nPHASE SUMMARY for {phase_data.name}:\n{summary_text}\n{border}") + logger.info( + f"{border}\nPHASE SUMMARY for {phase_data.name}:\n{summary_text}\n{border}" + ) # Write to unique game_manifesto in the timestamped folder with open(manifesto_path, "a") as f: @@ -369,16 +407,18 @@ def main(): to_saved_game_format(game, output_path=output_path) else: logger.info("Game file already exists, saving with unique filename.") - output_path = f'{output_path}_{time.strftime("%Y%m%d_%H%M%S")}.json' + output_path = f"{output_path}_{time.strftime('%Y%m%d_%H%M%S')}.json" to_saved_game_format(game, output_path=output_path) # Dump our error stats to JSON import json + with open(stats_file_path, "w") as stats_f: json.dump(model_error_stats, stats_f, indent=2) logger.info(f"Saved game data, manifesto, and error stats in: {result_folder}") logger.info("Done.") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..82cc0a7 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,4 @@ +[tool.ruff] +exclude = [ + "diplomacy" +] diff --git a/randomgame.py b/randomgame.py index a1cc4de..f80f170 100644 --- a/randomgame.py +++ b/randomgame.py @@ -6,17 +6,19 @@ from diplomacy.utils.export import to_saved_game_format # Alternatively, a map_name can be specified as an argument. e.g. Game(map_name='pure') game = Game() while not game.is_game_done: - # Getting the list of possible orders for all locations possible_orders = game.get_all_possible_orders() # For each power, randomly sampling a valid order for power_name, power in game.powers.items(): - power_orders = [random.choice(possible_orders[loc]) for loc in game.get_orderable_locations(power_name) - if possible_orders[loc]] + power_orders = [ + random.choice(possible_orders[loc]) + for loc in game.get_orderable_locations(power_name) + if possible_orders[loc] + ] game.set_orders(power_name, power_orders) - print(f'{power_name} orders: {power_orders}') + print(f"{power_name} orders: {power_orders}") # Messages can be sent locally with game.add_message # e.g. game.add_message(Message(sender='FRANCE', @@ -30,4 +32,4 @@ while not game.is_game_done: # Exporting the game to disk to visualize (game is appended to file) # Alternatively, we can do >> file.write(json.dumps(to_saved_game_format(game))) -to_saved_game_format(game, output_path='game.json') \ No newline at end of file +to_saved_game_format(game, output_path="game.json") diff --git a/setup.py b/setup.py index 29997ad..be8413f 100644 --- a/setup.py +++ b/setup.py @@ -14,54 +14,64 @@ # You should have received a copy of the GNU Affero General Public License along # with this program. If not, see . # ============================================================================== -""" Package installer """ +"""Package installer""" + import os import sys from setuptools import setup, find_packages # Import the current version -sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'diplomacy')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "diplomacy")) from version import PACKAGE_VERSION # Requiring python 3.5+. # To simplify code for Tornado coroutines return statements, we don't support Python 3.4 # ( more info here: http://www.tornadoweb.org/en/stable/guide/coroutines.html#coroutines ). if (sys.version_info.major, sys.version_info.minor) <= (3, 4): - print("This package is only compatible with Python 3.5+, but you are running Python {}.{}." - .format(sys.version_info.major, sys.version_info.minor)) + print( + "This package is only compatible with Python 3.5+, but you are running Python {}.{}.".format( + sys.version_info.major, sys.version_info.minor + ) + ) # ------------------------------------ # Configuration -setup(name='diplomacy', - version=PACKAGE_VERSION, - description='Diplomacy: DATC-Compliant Game Engine with Web Interface', - long_description=open(os.path.join(os.path.dirname(__file__), 'README.md')).read(), - long_description_content_type='text/markdown', - url='https://github.com/diplomacy/diplomacy', - author='Philip Paquette', - author_email='pcpaquette@gmail.com', - packages=find_packages(), - keywords='diplomacy diplomacy-game game negotiation', - python_requires='>=3.5', - include_package_data=True, - install_requires=[ - 'bcrypt', - 'coloredlogs', - 'python-dateutil', - 'pytz', - 'tornado>=5.0', - 'tqdm', - 'ujson', - ], - tests_require=['pytest'], - classifiers=['License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Topic :: Games/Entertainment :: Board Games'], - project_urls={'Bug Reports': 'https://github.com/diplomacy/diplomacy/issues', - 'Documentation': 'https://diplomacy.readthedocs.io/', - 'Source': 'https://github.com/diplomacy/diplomacy/'}) +setup( + name="diplomacy", + version=PACKAGE_VERSION, + description="Diplomacy: DATC-Compliant Game Engine with Web Interface", + long_description=open(os.path.join(os.path.dirname(__file__), "README.md")).read(), + long_description_content_type="text/markdown", + url="https://github.com/diplomacy/diplomacy", + author="Philip Paquette", + author_email="pcpaquette@gmail.com", + packages=find_packages(), + keywords="diplomacy diplomacy-game game negotiation", + python_requires=">=3.5", + include_package_data=True, + install_requires=[ + "bcrypt", + "coloredlogs", + "python-dateutil", + "pytz", + "tornado>=5.0", + "tqdm", + "ujson", + ], + tests_require=["pytest"], + classifiers=[ + "License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Topic :: Games/Entertainment :: Board Games", + ], + project_urls={ + "Bug Reports": "https://github.com/diplomacy/diplomacy/issues", + "Documentation": "https://diplomacy.readthedocs.io/", + "Source": "https://github.com/diplomacy/diplomacy/", + }, +) # ------------------------------------