import os
import json
import re
import logging
import ast

from typing import List, Dict, Optional
from dotenv import load_dotenv

# Anthropic
import anthropic

# Google Generative AI
# Make the gRPC layer more verbose (must be set before the genai import)
os.environ['GRPC_PYTHON_LOG_LEVEL'] = '10'
import google.generativeai as genai  # Import after setting log level

# DeepSeek (served through an OpenAI-compatible client)
from openai import OpenAI as DeepSeekOpenAI

# Keep this module's own logger at plain INFO
logger = logging.getLogger('lm_service_versus')
logger.setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)

load_dotenv()

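# A sample .env layout this module expects (values are placeholders; each
# client below reads its own key from the environment):
#
#   OPENAI_API_KEY=sk-...
#   ANTHROPIC_API_KEY=sk-ant-...
#   GEMINI_API_KEY=...
#   DEEPSEEK_API_KEY=...
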
##############################################################################
# 1) Base Interface
##############################################################################
class BaseModelClient:
    """
    Base interface for any LLM client we want to plug in.
    Each must provide:
      - generate_response(prompt: str) -> str
      - get_orders(board_state, power_name, possible_orders, conversation_text) -> List[str]
      - get_conversation_reply(power_name, conversation_so_far, game_phase) -> str
    """

    def __init__(self, model_name: str):
        self.model_name = model_name
        self.system_prompt_response = load_prompt("system_prompt_response.txt")
        self.system_prompt_conversation = load_prompt("system_prompt_conversation.txt")

    def generate_response(self, prompt: str) -> str:
        """
        Returns a raw string from the LLM.
        Subclasses override this.
        """
        raise NotImplementedError("Subclasses must implement generate_response().")

    def build_prompt(
        self,
        board_state,
        power_name: str,
        possible_orders: Dict[str, List[str]],
        conversation_text: str,
        phase_summaries: Optional[Dict[str, str]] = None,
    ) -> str:
        """
        Unified prompt approach: incorporate conversation and 'PARSABLE OUTPUT' requirements.
        """
        # Get our units and centers
        units_info = board_state["units"].get(power_name, [])
        centers_info = board_state["centers"].get(power_name, [])

        # Get the current phase
        year_phase = board_state["phase"]  # e.g. 'S1901M'

        # Get enemy units and centers and label them for each power
        enemy_units = {}
        enemy_centers = {}
        for power, info in board_state["units"].items():
            if power != power_name:
                enemy_units[power] = info
                enemy_centers[power] = board_state["centers"].get(power, [])

        summary = (
            f"Power: {power_name}\n"
            f"Current Phase: {year_phase}\n"
            f"Enemy Units: {enemy_units}\n"
            f"Enemy Centers: {enemy_centers}\n"
            f"Your Units: {units_info}\n"
            f"Your Centers: {centers_info}\n"
            f"Possible Orders:\n"
        )
        for loc, orders in possible_orders.items():
            summary += f"  {loc}: {orders}\n"

        # Load prompts
        few_shot_example = load_prompt("few_shot_example.txt")
        instructions = load_prompt("instructions.txt")

        # 1) Prepare a block of text for the phase_summaries
        if phase_summaries:
            historical_summaries = "\nPAST PHASE SUMMARIES:\n"
            for phase_key, summary_txt in phase_summaries.items():
                historical_summaries += f"\nPHASE {phase_key}:\n{summary_txt}\n"
        else:
            historical_summaries = "\n(No historical summaries provided)\n"

        prompt = (
            "Relevant Conversation:\n" + conversation_text + "\n\n"
            + "Historical Summaries:\n" + historical_summaries + "\n\n"
            + summary + few_shot_example + "\n"
            + instructions
        )
        return prompt

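    # The assembled prompt has roughly this shape (a sketch; the few-shot and
    # instruction text come from ./prompts and vary by deployment):
    #
    #   Relevant Conversation:
    #   <messages visible to this power>
    #
    #   Historical Summaries:
    #   PAST PHASE SUMMARIES: ...
    #
    #   Power: FRANCE
    #   Current Phase: S1901M
    #   Enemy Units: {...}
    #   ...
    #   Possible Orders:
    #     PAR: ['A PAR H', 'A PAR - BUR', ...]
    #   <few-shot example>
    #   <instructions>
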
    def get_orders(
        self,
        board_state,
        power_name: str,
        possible_orders: Dict[str, List[str]],
        conversation_text: str,
        phase_summaries: Optional[Dict[str, str]] = None,
        model_error_stats=None  # Optional error-tracking dict
    ) -> List[str]:
        """
        1) Builds the prompt with conversation context if available
        2) Calls the LLM
        3) Parses the JSON block
        """
        prompt = self.build_prompt(board_state, power_name, possible_orders, conversation_text, phase_summaries)

        raw_response = ""

        try:
            raw_response = self.generate_response(prompt)
            logger.info(f"[{self.model_name}] Raw LLM response for {power_name}:\n{raw_response}")

            # Attempt to parse the final "orders" from the LLM
            move_list = self._extract_moves(raw_response, power_name)

            if not move_list:
                logger.warning(f"[{self.model_name}] Could not extract moves for {power_name}. Using fallback.")
                if model_error_stats is not None:
                    model_error_stats[self.model_name]["order_decoding_errors"] += 1
                return self.fallback_orders(possible_orders)

            # Validate or fall back
            validated_moves = self._validate_orders(move_list, possible_orders)
            return validated_moves

        except Exception as e:
            logger.error(f"[{self.model_name}] LLM error for {power_name}: {e}")
            return self.fallback_orders(possible_orders)

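    # A hedged usage sketch (board_state and possible_orders come from the
    # diplomacy engine; see example_game_loop below):
    #
    #   orders = client.get_orders(board_state, "FRANCE", possible_orders,
    #                              conversation_text="", phase_summaries=None)
    #   # -> e.g. ['A PAR - BUR', 'A MAR S A PAR - BUR', 'F BRE - MAO']
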
    def _extract_moves(self, raw_response: str, power_name: str) -> Optional[List[str]]:
        """
        Attempt multiple parse strategies to find a JSON array of moves.

        1. Regex for PARSABLE OUTPUT lines.
        2. If that fails, also look for fenced code blocks with { ... }.
        3. Attempt bracket-based fallback if needed.

        Returns a list of move strings, or None if everything fails.
        """
        # 1) Regex for "PARSABLE OUTPUT:{...}"
        pattern = r"PARSABLE OUTPUT:\s*(\{[\s\S]*\})"
        matches = re.search(pattern, raw_response, re.DOTALL)

        if not matches:
            # Some LLMs omit the colon or wrap the JSON in triple-backtick fences.
            logger.debug(f"[{self.model_name}] Regex parse #1 failed for {power_name}. Trying alternative patterns.")

            # 1b) Check for inline JSON after "PARSABLE OUTPUT"
            pattern_alt = r"PARSABLE OUTPUT\s*\{(.*?)\}\s*$"
            matches = re.search(pattern_alt, raw_response, re.DOTALL)

        if not matches:
            logger.debug(f"[{self.model_name}] Regex parse #2 failed for {power_name}. Trying triple-backtick code fences.")

        # 2) If still no match, check for triple-backtick code fences containing JSON
        if not matches:
            code_fence_pattern = r"```json\s*\{(.*?)\}\s*```"
            matches = re.search(code_fence_pattern, raw_response, re.DOTALL)
            if matches:
                logger.debug(f"[{self.model_name}] Found triple-backtick JSON block for {power_name}.")

        # 3) Attempt to parse JSON if we found anything
        json_text = None
        if matches:
            # Normalize braces around the captured group (some patterns capture
            # the outer braces, some do not)
            captured = matches.group(1).strip()
            if captured.startswith("{{"):
                json_text = captured[1:-1]
            elif captured.startswith("{"):
                json_text = captured
            else:
                json_text = "{%s}" % captured

            json_text = json_text.strip()

        if not json_text:
            logger.debug(f"[{self.model_name}] No JSON text found in LLM response for {power_name}.")
            return None

        # 3a) Try JSON loading
        try:
            data = json.loads(json_text)
            return data.get("orders", None)
        except json.JSONDecodeError as e:
            logger.warning(f"[{self.model_name}] JSON decode failed for {power_name}: {e}. Trying bracket fallback.")

        # 3b) Bracket fallback: look for the substring after "orders",
        # e.g. "orders": ['A BUD H'], and parse it. This is risky but can help
        # with minor JSON format errors. We only do this if we see "orders": ...
        bracket_pattern = r'["\']orders["\']\s*:\s*\[([^\]]*)\]'
        bracket_match = re.search(bracket_pattern, json_text, re.DOTALL)
        if bracket_match:
            try:
                raw_list_str = "[" + bracket_match.group(1).strip() + "]"
                moves = ast.literal_eval(raw_list_str)
                if isinstance(moves, list):
                    return moves
            except Exception as e2:
                logger.warning(f"[{self.model_name}] Bracket fallback parse also failed for {power_name}: {e2}")

        # If all attempts failed
        return None

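    # Doctest-style sketch of strategy #1 (the sample response text is
    # hypothetical):
    #
    #   >>> sample = 'Reasoning...\nPARSABLE OUTPUT:{"orders": ["A PAR - BUR", "F BRE - MAO"]}'
    #   >>> m = re.search(r"PARSABLE OUTPUT:\s*(\{[\s\S]*\})", sample, re.DOTALL)
    #   >>> json.loads(m.group(1))["orders"]
    #   ['A PAR - BUR', 'F BRE - MAO']
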
    def _validate_orders(self, moves: List[str], possible_orders: Dict[str, List[str]]) -> List[str]:
        """
        Filter out invalid moves, fill missing locations with HOLD, else fall back.
        """
        logger.debug(f"[{self.model_name}] Proposed LLM moves: {moves}")
        validated = []
        used_locs = set()

        if not isinstance(moves, list):
            logger.debug(f"[{self.model_name}] Moves not a list, falling back.")
            return self.fallback_orders(possible_orders)

        for move in moves:
            move_str = move.strip()
            # Check if it's in the possible orders
            if any(move_str in loc_orders for loc_orders in possible_orders.values()):
                validated.append(move_str)
                parts = move_str.split()
                if len(parts) >= 2:
                    used_locs.add(parts[1][:3])
            else:
                logger.debug(f"[{self.model_name}] Invalid move from LLM: {move_str}")

        # Fill missing locations with a hold
        for loc, orders_list in possible_orders.items():
            if loc not in used_locs and orders_list:
                hold_candidates = [o for o in orders_list if o.endswith("H")]
                validated.append(hold_candidates[0] if hold_candidates else orders_list[0])

        if not validated:
            logger.warning(f"[{self.model_name}] All moves invalid, falling back.")
            return self.fallback_orders(possible_orders)

        logger.debug(f"[{self.model_name}] Validated moves: {validated}")
        return validated

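    # Worked example (hypothetical inputs): with
    #   possible_orders = {"PAR": ["A PAR H", "A PAR - BUR"], "BRE": ["F BRE - MAO", "F BRE H"]}
    #   moves = ["A PAR - BUR", "A MUN H"]
    # "A MUN H" is rejected (not a possible order), "A PAR - BUR" marks PAR as
    # used, and BRE is filled with its hold, giving ["A PAR - BUR", "F BRE H"].
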
    def fallback_orders(self, possible_orders: Dict[str, List[str]]) -> List[str]:
        """
        Just picks HOLD if possible, else the first option.
        """
        fallback = []
        for loc, orders_list in possible_orders.items():
            if orders_list:
                holds = [o for o in orders_list if o.endswith("H")]
                fallback.append(holds[0] if holds else orders_list[0])
        return fallback

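    # For the same hypothetical possible_orders as above, fallback_orders
    # returns ["A PAR H", "F BRE H"] -- a hold for every location.
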
    def build_conversation_reply(
        self,
        power_name: str,
        conversation_so_far: str,
        game_phase: str,
        phase_summaries: Optional[Dict[str, str]] = None,
    ) -> str:
        """
        Build the prompt asking the model to produce a single message in
        valid JSON with 'message_type' etc.
        """
        return load_prompt("build_conversation_reply.txt").format(
            power_name=power_name,
            game_phase=game_phase,
            phase_summaries=phase_summaries,
            conversation_so_far=conversation_so_far
        )

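    # Assumes ./prompts/build_conversation_reply.txt contains the str.format
    # placeholders {power_name}, {game_phase}, {phase_summaries}, and
    # {conversation_so_far}; a placeholder in the template with no matching
    # keyword raises KeyError, while unused keywords are ignored.
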
    def get_conversation_reply(
        self,
        power_name: str,
        conversation_so_far: str,
        game_phase: str,
        phase_summaries: Optional[Dict[str, str]] = None,
    ) -> str:
        """
        Overridden by subclasses.
        """
        raise NotImplementedError("Subclasses must implement get_conversation_reply().")


##############################################################################
# 2) Concrete Implementations
##############################################################################

class OpenAIClient(BaseModelClient):
    """
    For 'o3-mini', 'gpt-4o', or other OpenAI model calls.
    """
    def __init__(self, model_name: str):
        super().__init__(model_name)
        from openai import OpenAI  # Import the new client
        self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

    def generate_response(self, prompt: str) -> str:
        # Updated to the new API format
        system_prompt = self.system_prompt_response
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": prompt}
                ],
            )
            if not response or not hasattr(response, "choices") or not response.choices:
                logger.warning(f"[{self.model_name}] Empty or invalid result in generate_response. Returning empty.")
                return ""
            return response.choices[0].message.content.strip()
        except json.JSONDecodeError as json_err:
            logger.error(f"[{self.model_name}] JSON decoding failed in generate_response: {json_err}")
            return ""
        except Exception as e:
            logger.error(f"[{self.model_name}] Unexpected error in generate_response: {e}")
            return ""

    def get_conversation_reply(
        self,
        power_name: str,
        conversation_so_far: str,
        game_phase: str,
        phase_summaries: Optional[Dict[str, str]] = None,
    ) -> str:
        """
        Produces a single message in the appropriate JSON format.
        """
        # Load the system prompt, formatted with the power name and game phase
        system_prompt = self.system_prompt_conversation.format(power_name=power_name, game_phase=game_phase)
        conversation_prompt = self.build_conversation_reply(power_name, conversation_so_far, game_phase, phase_summaries)

        try:
            # Perform the request
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": conversation_prompt}
                ],
                max_completion_tokens=2000
            )

            # If there's no valid response or choices, return empty
            if not response or not hasattr(response, "choices") or not response.choices:
                logger.warning(f"[{self.model_name}] Empty or invalid response for {power_name}. Returning empty.")
                return ""

            # Return the content (the OpenAI library parses the HTTP response;
            # the except clauses below are a safety net)
            return response.choices[0].message.content.strip()

        except json.JSONDecodeError as json_err:
            logger.error(f"[{self.model_name}] JSON decoding failed for {power_name}: {json_err}")
            return ""  # Fallback
        except Exception as e:
            logger.error(f"[{self.model_name}] Unexpected error for {power_name}: {e}")
            return ""

class ClaudeClient(BaseModelClient):
    """
    For 'claude-3-5-sonnet-20241022', 'claude-3-5-haiku-20241022', etc.
    """
    def __init__(self, model_name: str):
        super().__init__(model_name)
        self.client = anthropic.Anthropic(
            api_key=os.environ.get("ANTHROPIC_API_KEY")
        )

    def generate_response(self, prompt: str) -> str:
        system_prompt = self.system_prompt_response
        # Updated Claude messages format
        try:
            response = self.client.messages.create(
                model=self.model_name,
                max_tokens=2000,
                system=system_prompt,  # system is now a top-level parameter
                messages=[
                    {"role": "user", "content": prompt}
                ]
            )
            if not response.content:
                logger.warning(f"[{self.model_name}] Empty content in Claude generate_response. Returning empty.")
                return ""
            return response.content[0].text.strip()
        except json.JSONDecodeError as json_err:
            logger.error(f"[{self.model_name}] JSON decoding failed in generate_response: {json_err}")
            return ""
        except Exception as e:
            logger.error(f"[{self.model_name}] Unexpected error in generate_response: {e}")
            return ""

    def get_conversation_reply(
        self,
        power_name: str,
        conversation_so_far: str,
        game_phase: str,
        phase_summaries: Optional[Dict[str, str]] = None,
    ) -> str:
        system_prompt = f"You are playing as {power_name} in this Diplomacy negotiation phase {game_phase}."
        user_prompt = self.build_conversation_reply(power_name, conversation_so_far, game_phase, phase_summaries)
        try:
            response = self.client.messages.create(
                model=self.model_name,
                system=system_prompt,
                messages=[{"role": "user", "content": user_prompt}],
                max_tokens=2000
            )
            if not response.content:
                logger.warning(f"[{self.model_name}] No content in Claude conversation. Returning empty.")
                return ""
            return response.content[0].text.strip()
        except json.JSONDecodeError as json_err:
            logger.error(f"[{self.model_name}] JSON decoding failed in conversation: {json_err}")
            return ""
        except Exception as e:
            logger.error(f"[{self.model_name}] Unexpected error in conversation: {e}")
            return ""

class GeminiClient(BaseModelClient):
    """
    For 'gemini-1.5-flash' or other Google Generative AI models.
    """
    def __init__(self, model_name: str):
        super().__init__(model_name)
        genai.configure(api_key=os.environ.get("GEMINI_API_KEY"))
        self.generation_config = {
            "temperature": 0.7,
            "max_output_tokens": 2000,
        }

    def generate_response(self, prompt: str) -> str:
        system_prompt = self.system_prompt_response
        full_prompt = system_prompt + prompt

        try:
            model = genai.GenerativeModel(
                self.model_name,
                generation_config=self.generation_config
            )
            response = model.generate_content(full_prompt)
            if not response or not response.text:
                logger.warning(f"[{self.model_name}] Empty Gemini generate_response. Returning empty.")
                return ""
            return response.text.strip()
        except Exception as e:
            logger.error(f"[{self.model_name}] Error in Gemini generate_response: {e}")
            return ""

    def get_conversation_reply(
        self,
        power_name: str,
        conversation_so_far: str,
        game_phase: str,
        phase_summaries: Optional[Dict[str, str]] = None,
    ) -> str:
        """
        Produce a single short conversation message from the Gemini model,
        given the existing conversation context.
        """
        # Similar approach: build a system plus user prompt, then call model.generate_content
        system_prompt = f"You are playing as {power_name} in this Diplomacy negotiation phase {game_phase}.\n"
        user_prompt = self.build_conversation_reply(power_name, conversation_so_far, game_phase, phase_summaries)
        full_prompt = system_prompt + user_prompt

        try:
            model = genai.GenerativeModel(
                self.model_name,
                generation_config=self.generation_config
            )
            response = model.generate_content(full_prompt)
            if not response or not response.text:
                logger.warning(f"[{self.model_name}] Empty Gemini conversation response. Returning empty.")
                return ""
            else:
                logger.info(f"[{self.model_name}] Gemini message successfully generated.")
                return response.text.strip()
        except json.JSONDecodeError as json_err:
            logger.error(f"[{self.model_name}] JSON decode error in conversation: {json_err}")
            return ""
        except Exception as e:
            logger.error(f"[{self.model_name}] Error in Gemini get_conversation_reply: {e}")
            return ""

class DeepSeekClient(BaseModelClient):
    """
    For DeepSeek R1 ('deepseek-reasoner').
    """
    def __init__(self, model_name: str):
        super().__init__(model_name)
        self.api_key = os.environ.get("DEEPSEEK_API_KEY")
        self.client = DeepSeekOpenAI(
            api_key=self.api_key,
            base_url="https://api.deepseek.com/"
        )

    def generate_response(self, prompt: str) -> str:
        system_prompt = self.system_prompt_response
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": prompt}
                ],
                stream=False
            )
            logger.debug(f"[{self.model_name}] Raw DeepSeek response:\n{response}")

            if not response or not response.choices:
                logger.warning(f"[{self.model_name}] No valid response in generate_response.")
                return ""

            content = response.choices[0].message.content.strip()
            if not content:
                logger.warning(f"[{self.model_name}] DeepSeek returned empty content.")
                return ""

            try:
                json_response = json.loads(content)
                required_fields = ["message_type", "content"]
                if json_response.get("message_type") == "private":
                    required_fields.append("recipient")
                if not all(field in json_response for field in required_fields):
                    logger.error(f"[{self.model_name}] Missing required fields in response: {content}")
                    return ""
                return content
            except json.JSONDecodeError:
                logger.error(f"[{self.model_name}] Response is not valid JSON: {content}")
                # Crude repair: models sometimes emit single-quoted pseudo-JSON
                content = content.replace("'", '"')
                try:
                    json.loads(content)
                    return content
                except json.JSONDecodeError:
                    return ""

        except Exception as e:
            logger.error(f"[{self.model_name}] Unexpected error in generate_response: {e}")
            return ""

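    # Message shapes this validation accepts (illustrative values only):
    #
    #   {"message_type": "global",  "content": "I propose a DMZ in the Channel."}
    #   {"message_type": "private", "recipient": "ENGLAND", "content": "..."}
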
    def get_conversation_reply(
        self,
        power_name: str,
        conversation_so_far: str,
        game_phase: str,
        phase_summaries: Optional[Dict[str, str]] = None,
    ) -> str:
        system_prompt = self.system_prompt_conversation.format(power_name=power_name, game_phase=game_phase)
        user_prompt = self.build_conversation_reply(power_name, conversation_so_far, game_phase, phase_summaries)
        user_prompt += "\n\nPlease provide ONLY a single JSON object as per the examples above."

        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}],
                max_completion_tokens=2000
            )
            logger.debug(f"[{self.model_name}] Raw DeepSeek conversation response:\n{response}")

            if not response or not response.choices:
                logger.warning(f"[{self.model_name}] No valid choices in conversation reply.")
                return ""
            return response.choices[0].message.content.strip()
        except json.JSONDecodeError as json_err:
            logger.error(f"[{self.model_name}] JSON decode error in conversation: {json_err}")
            return ""
        except Exception as e:
            logger.error(f"[{self.model_name}] Unexpected error in conversation: {e}")
            return ""


##############################################################################
# 3) Factory to Load Model Client
##############################################################################

def load_model_client(model_id: str) -> BaseModelClient:
    """
    Returns the appropriate LLM client for a given model_id string.
    Example usage:
        client = load_model_client("claude-3-5-sonnet-20241022")
    """
    # Basic pattern matching or direct mapping
    lower_id = model_id.lower()
    if "claude" in lower_id:
        return ClaudeClient(model_id)
    elif "gemini" in lower_id:
        return GeminiClient(model_id)
    elif "deepseek" in lower_id:
        return DeepSeekClient(model_id)
    else:
        # Default to OpenAI
        return OpenAIClient(model_id)

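# Dispatch is substring-based, so for example:
#   load_model_client("claude-3-5-haiku-20241022")  -> ClaudeClient
#   load_model_client("gemini-1.5-flash")           -> GeminiClient
#   load_model_client("deepseek-reasoner")          -> DeepSeekClient
#   load_model_client("gpt-4o")                     -> OpenAIClient (default branch)
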
##############################################################################
# 4) Example Usage in a Diplomacy "main" or Similar
##############################################################################

def assign_models_to_powers():
    """
    Example usage: define which model each power uses.
    Returns a dict: { power_name: model_id, ... }
    POWERS = ['AUSTRIA', 'ENGLAND', 'FRANCE', 'GERMANY', 'ITALY', 'RUSSIA', 'TURKEY']
    """
    # "RUSSIA": "deepseek-reasoner" -- DeepSeek API having issues
    return {
        "FRANCE": "claude-3-5-haiku-20241022",
        "GERMANY": "claude-3-5-haiku-20241022",
        "ENGLAND": "claude-3-5-haiku-20241022",
        "RUSSIA": "claude-3-5-haiku-20241022",
        "ITALY": "claude-3-5-haiku-20241022",
        "AUSTRIA": "claude-3-5-haiku-20241022",
        "TURKEY": "claude-3-5-haiku-20241022",
    }

    # return {
    #     "FRANCE": "o3-mini",
    #     "GERMANY": "claude-3-5-sonnet-20241022",
    #     "ENGLAND": "gemini-2.0-flash",
    #     "RUSSIA": "gemini-2.0-flash-lite-preview-02-05",
    #     "ITALY": "gpt-4o",
    #     "AUSTRIA": "gpt-4o-mini",
    #     "TURKEY": "claude-3-5-haiku-20241022",
    # }

def example_game_loop(game):
    """
    Pseudocode: integrate with the Diplomacy loop.
    """
    # Suppose we gather all active powers
    active_powers = [(p_name, p_obj) for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()]
    power_model_mapping = assign_models_to_powers()

    for power_name, power_obj in active_powers:
        model_id = power_model_mapping.get(power_name, "o3-mini")
        client = load_model_client(model_id)

        # Get possible orders from the game
        possible_orders = game.get_all_possible_orders()
        board_state = game.get_state()

        # Get orders from the client (no press in this minimal loop, so pass
        # an empty conversation)
        orders = client.get_orders(board_state, power_name, possible_orders, conversation_text="")
        game.set_orders(power_name, orders)

    # Then process, etc.
    game.process()

class LMServiceVersus:
    """
    Optional wrapper class if you want extra control.
    For example, you could store or reuse clients, etc.
    """
    def __init__(self):
        self.power_model_map = assign_models_to_powers()

    def get_orders_for_power(self, game, power_name):
        model_id = self.power_model_map.get(power_name, "o3-mini")
        client = load_model_client(model_id)
        # gather_possible_orders is assumed to be provided by the surrounding
        # application (e.g. a per-power filter over game.get_all_possible_orders())
        possible_orders = gather_possible_orders(game, power_name)
        board_state = game.get_state()
        return client.get_orders(board_state, power_name, possible_orders, conversation_text="")

##############################################################################
# 1) Add a method to filter visible messages (near top-level or in BaseModelClient)
##############################################################################
def get_visible_messages_for_power(conversation_messages, power_name):
    """
    Returns a chronological subset of conversation_messages that power_name can legitimately see.
    """
    visible = []
    for msg in conversation_messages:
        # GLOBAL might be 'ALL' or 'GLOBAL' depending on your usage
        if (
            msg['recipient'] == 'ALL' or msg['recipient'] == 'GLOBAL'
            or msg['sender'] == power_name
            or msg['recipient'] == power_name
        ):
            visible.append(msg)
    return visible  # already in chronological order if appended that way

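# Doctest-style sketch (the messages are hypothetical):
#
#   >>> msgs = [
#   ...     {"sender": "FRANCE",  "recipient": "GLOBAL", "content": "Peace?"},
#   ...     {"sender": "ENGLAND", "recipient": "FRANCE", "content": "Deal."},
#   ...     {"sender": "GERMANY", "recipient": "RUSSIA", "content": "Attack!"},
#   ... ]
#   >>> [m["sender"] for m in get_visible_messages_for_power(msgs, "FRANCE")]
#   ['FRANCE', 'ENGLAND']
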
def load_prompt(filename: str) -> str:
    """Helper to load prompt text from a file."""
    with open(f"./prompts/{filename}", "r") as f:
        return f.read().strip()
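
# Assumed prompt files under ./prompts/ (all referenced elsewhere in this module):
#   system_prompt_response.txt, system_prompt_conversation.txt,
#   few_shot_example.txt, instructions.txt, build_conversation_reply.txt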