mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-05-02 17:46:00 +00:00
Ruff format and check
This commit is contained in:
parent
aadedd6512
commit
ad817695ef
13 changed files with 738 additions and 973 deletions
|
|
@ -10,7 +10,7 @@ import ast
|
||||||
from ..config import config
|
from ..config import config
|
||||||
|
|
||||||
# Assuming BaseModelClient is importable from clients.py in the same directory
|
# Assuming BaseModelClient is importable from clients.py in the same directory
|
||||||
from .clients import BaseModelClient, load_model_client
|
from .clients import BaseModelClient
|
||||||
|
|
||||||
# Import load_prompt and the new logging wrapper from utils
|
# Import load_prompt and the new logging wrapper from utils
|
||||||
from .utils import load_prompt, run_llm_and_log, log_llm_response, get_prompt_path
|
from .utils import load_prompt, run_llm_and_log, log_llm_response, get_prompt_path
|
||||||
|
|
@ -336,7 +336,7 @@ class DiplomacyAgent:
|
||||||
)
|
)
|
||||||
# If even the last resort doesn't give a dict, return empty dict
|
# If even the last resort doesn't give a dict, return empty dict
|
||||||
return {}
|
return {}
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error(f"[{self.power_name}] All JSON extraction attempts failed. Original text: {original_text[:500]}...")
|
logger.error(f"[{self.power_name}] All JSON extraction attempts failed. Original text: {original_text[:500]}...")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
@ -965,7 +965,7 @@ class DiplomacyAgent:
|
||||||
other_powers = [p for p in game.powers if p != power_name]
|
other_powers = [p for p in game.powers if p != power_name]
|
||||||
|
|
||||||
# Create a readable board state string from the board_state dict
|
# Create a readable board state string from the board_state dict
|
||||||
board_state_str = f"Board State:\n"
|
board_state_str = "Board State:\n"
|
||||||
for p_name, power_data in board_state.get("powers", {}).items():
|
for p_name, power_data in board_state.get("powers", {}).items():
|
||||||
# Get units and centers from the board state
|
# Get units and centers from the board state
|
||||||
units = power_data.get("units", [])
|
units = power_data.get("units", [])
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,11 @@
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
from json import JSONDecodeError
|
|
||||||
import re
|
import re
|
||||||
import logging
|
import logging
|
||||||
import ast # For literal_eval in JSON fallback parsing
|
import ast # For literal_eval in JSON fallback parsing
|
||||||
import aiohttp # For direct HTTP requests to Responses API
|
import aiohttp # For direct HTTP requests to Responses API
|
||||||
|
|
||||||
from typing import List, Dict, Optional, Any, Tuple, NamedTuple
|
from typing import List, Dict, Optional, Tuple, NamedTuple
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
# Use Async versions of clients
|
# Use Async versions of clients
|
||||||
|
|
@ -21,9 +20,9 @@ from together import AsyncTogether
|
||||||
from together.error import APIError as TogetherAPIError # For specific error handling
|
from together.error import APIError as TogetherAPIError # For specific error handling
|
||||||
|
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from diplomacy.engine.message import GLOBAL
|
|
||||||
from .game_history import GameHistory
|
from .game_history import GameHistory
|
||||||
from .utils import load_prompt, run_llm_and_log, log_llm_response, generate_random_seed, get_prompt_path
|
from .utils import load_prompt, run_llm_and_log, log_llm_response, generate_random_seed, get_prompt_path
|
||||||
|
|
||||||
# Import DiplomacyAgent for type hinting if needed, but avoid circular import if possible
|
# Import DiplomacyAgent for type hinting if needed, but avoid circular import if possible
|
||||||
from .prompt_constructor import construct_order_generation_prompt, build_context_prompt
|
from .prompt_constructor import construct_order_generation_prompt, build_context_prompt
|
||||||
# Moved formatter imports to avoid circular import - imported locally where needed
|
# Moved formatter imports to avoid circular import - imported locally where needed
|
||||||
|
|
@ -36,6 +35,7 @@ logger.setLevel(logging.DEBUG) # Keep debug for now during async changes
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
##############################################################################
|
##############################################################################
|
||||||
# 1) Base Interface
|
# 1) Base Interface
|
||||||
##############################################################################
|
##############################################################################
|
||||||
|
|
@ -118,24 +118,19 @@ class BaseModelClient:
|
||||||
log_file_path=log_file_path,
|
log_file_path=log_file_path,
|
||||||
power_name=power_name,
|
power_name=power_name,
|
||||||
phase=phase,
|
phase=phase,
|
||||||
response_type='order', # Context for run_llm_and_log's own error logging
|
response_type="order", # Context for run_llm_and_log's own error logging
|
||||||
temperature=0
|
temperature=0,
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f"[{self.model_name}] Raw LLM response for {power_name} orders:\n{raw_response}"
|
|
||||||
)
|
)
|
||||||
|
logger.debug(f"[{self.model_name}] Raw LLM response for {power_name} orders:\n{raw_response}")
|
||||||
|
|
||||||
# Conditionally format the response based on USE_UNFORMATTED_PROMPTS
|
# Conditionally format the response based on USE_UNFORMATTED_PROMPTS
|
||||||
if config.USE_UNFORMATTED_PROMPTS:
|
if config.USE_UNFORMATTED_PROMPTS:
|
||||||
# Local import to avoid circular dependency
|
# Local import to avoid circular dependency
|
||||||
from .formatter import format_with_gemini_flash, FORMAT_ORDERS
|
from .formatter import format_with_gemini_flash, FORMAT_ORDERS
|
||||||
|
|
||||||
# Format the natural language response into structured format
|
# Format the natural language response into structured format
|
||||||
formatted_response = await format_with_gemini_flash(
|
formatted_response = await format_with_gemini_flash(
|
||||||
raw_response,
|
raw_response, FORMAT_ORDERS, power_name=power_name, phase=phase, log_file_path=log_file_path
|
||||||
FORMAT_ORDERS,
|
|
||||||
power_name=power_name,
|
|
||||||
phase=phase,
|
|
||||||
log_file_path=log_file_path
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Use the raw response directly (already formatted)
|
# Use the raw response directly (already formatted)
|
||||||
|
|
@ -145,9 +140,7 @@ class BaseModelClient:
|
||||||
move_list = self._extract_moves(formatted_response, power_name)
|
move_list = self._extract_moves(formatted_response, power_name)
|
||||||
|
|
||||||
if not move_list:
|
if not move_list:
|
||||||
logger.warning(
|
logger.warning(f"[{self.model_name}] Could not extract moves for {power_name}. Using fallback.")
|
||||||
f"[{self.model_name}] Could not extract moves for {power_name}. Using fallback."
|
|
||||||
)
|
|
||||||
if model_error_stats is not None and self.model_name in model_error_stats:
|
if model_error_stats is not None and self.model_name in model_error_stats:
|
||||||
model_error_stats[self.model_name].setdefault("order_decoding_errors", 0)
|
model_error_stats[self.model_name].setdefault("order_decoding_errors", 0)
|
||||||
model_error_stats[self.model_name]["order_decoding_errors"] += 1
|
model_error_stats[self.model_name]["order_decoding_errors"] += 1
|
||||||
|
|
@ -195,7 +188,7 @@ class BaseModelClient:
|
||||||
response_type="order_generation", # Specific type for CSV logging
|
response_type="order_generation", # Specific type for CSV logging
|
||||||
raw_input_prompt=prompt, # Renamed from 'prompt' to match log_llm_response arg
|
raw_input_prompt=prompt, # Renamed from 'prompt' to match log_llm_response arg
|
||||||
raw_response=raw_response,
|
raw_response=raw_response,
|
||||||
success=success_status
|
success=success_status,
|
||||||
# token_usage and cost can be added later if available and if log_llm_response supports them
|
# token_usage and cost can be added later if available and if log_llm_response supports them
|
||||||
)
|
)
|
||||||
return parsed_orders_for_return
|
return parsed_orders_for_return
|
||||||
|
|
@ -216,9 +209,7 @@ class BaseModelClient:
|
||||||
|
|
||||||
if not matches:
|
if not matches:
|
||||||
# Some LLMs might not put the colon or might have triple backtick fences.
|
# Some LLMs might not put the colon or might have triple backtick fences.
|
||||||
logger.debug(
|
logger.debug(f"[{self.model_name}] Regex parse #1 failed for {power_name}. Trying alternative patterns.")
|
||||||
f"[{self.model_name}] Regex parse #1 failed for {power_name}. Trying alternative patterns."
|
|
||||||
)
|
|
||||||
|
|
||||||
# 1b) Check for inline JSON after "PARSABLE OUTPUT"
|
# 1b) Check for inline JSON after "PARSABLE OUTPUT"
|
||||||
pattern_alt = r"PARSABLE OUTPUT\s*\{(.*?)\}\s*$"
|
pattern_alt = r"PARSABLE OUTPUT\s*\{(.*?)\}\s*$"
|
||||||
|
|
@ -226,47 +217,35 @@ class BaseModelClient:
|
||||||
|
|
||||||
if not matches:
|
if not matches:
|
||||||
# 1c) Check for **PARSABLE OUTPUT:** pattern (with asterisks)
|
# 1c) Check for **PARSABLE OUTPUT:** pattern (with asterisks)
|
||||||
logger.debug(
|
logger.debug(f"[{self.model_name}] Regex parse #2 failed for {power_name}. Trying asterisk-wrapped pattern.")
|
||||||
f"[{self.model_name}] Regex parse #2 failed for {power_name}. Trying asterisk-wrapped pattern."
|
|
||||||
)
|
|
||||||
pattern_asterisk = r"\*\*PARSABLE OUTPUT:\*\*\s*(\{[\s\S]*?\})"
|
pattern_asterisk = r"\*\*PARSABLE OUTPUT:\*\*\s*(\{[\s\S]*?\})"
|
||||||
matches = re.search(pattern_asterisk, raw_response, re.DOTALL)
|
matches = re.search(pattern_asterisk, raw_response, re.DOTALL)
|
||||||
|
|
||||||
if not matches:
|
if not matches:
|
||||||
logger.debug(
|
logger.debug(f"[{self.model_name}] Regex parse #3 failed for {power_name}. Trying triple-backtick code fences.")
|
||||||
f"[{self.model_name}] Regex parse #3 failed for {power_name}. Trying triple-backtick code fences."
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2) If still no match, check for triple-backtick code fences containing JSON
|
# 2) If still no match, check for triple-backtick code fences containing JSON
|
||||||
if not matches:
|
if not matches:
|
||||||
code_fence_pattern = r"```json\n(.*?)\n```"
|
code_fence_pattern = r"```json\n(.*?)\n```"
|
||||||
matches = re.search(code_fence_pattern, raw_response, re.DOTALL)
|
matches = re.search(code_fence_pattern, raw_response, re.DOTALL)
|
||||||
if matches:
|
if matches:
|
||||||
logger.debug(
|
logger.debug(f"[{self.model_name}] Found triple-backtick JSON block for {power_name}.")
|
||||||
f"[{self.model_name}] Found triple-backtick JSON block for {power_name}."
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2b) Also try plain ``` code fences without json marker
|
# 2b) Also try plain ``` code fences without json marker
|
||||||
if not matches:
|
if not matches:
|
||||||
code_fence_plain = r"```\n(.*?)\n```"
|
code_fence_plain = r"```\n(.*?)\n```"
|
||||||
matches = re.search(code_fence_plain, raw_response, re.DOTALL)
|
matches = re.search(code_fence_plain, raw_response, re.DOTALL)
|
||||||
if matches:
|
if matches:
|
||||||
logger.debug(
|
logger.debug(f"[{self.model_name}] Found plain triple-backtick block for {power_name}.")
|
||||||
f"[{self.model_name}] Found plain triple-backtick block for {power_name}."
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2c) Try to find bare JSON object anywhere in the response
|
# 2c) Try to find bare JSON object anywhere in the response
|
||||||
if not matches:
|
if not matches:
|
||||||
logger.debug(
|
logger.debug(f"[{self.model_name}] No explicit markers found for {power_name}. Looking for bare JSON.")
|
||||||
f"[{self.model_name}] No explicit markers found for {power_name}. Looking for bare JSON."
|
|
||||||
)
|
|
||||||
# Look for a JSON object that contains "orders" key
|
# Look for a JSON object that contains "orders" key
|
||||||
bare_json_pattern = r'(\{[^{}]*"orders"\s*:\s*\[[^\]]*\][^{}]*\})'
|
bare_json_pattern = r'(\{[^{}]*"orders"\s*:\s*\[[^\]]*\][^{}]*\})'
|
||||||
matches = re.search(bare_json_pattern, raw_response, re.DOTALL)
|
matches = re.search(bare_json_pattern, raw_response, re.DOTALL)
|
||||||
if matches:
|
if matches:
|
||||||
logger.debug(
|
logger.debug(f"[{self.model_name}] Found bare JSON object with 'orders' key for {power_name}.")
|
||||||
f"[{self.model_name}] Found bare JSON object with 'orders' key for {power_name}."
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3) Attempt to parse JSON if we found anything
|
# 3) Attempt to parse JSON if we found anything
|
||||||
json_text = None
|
json_text = None
|
||||||
|
|
@ -283,9 +262,7 @@ class BaseModelClient:
|
||||||
json_text = json_text.strip()
|
json_text = json_text.strip()
|
||||||
|
|
||||||
if not json_text:
|
if not json_text:
|
||||||
logger.debug(
|
logger.debug(f"[{self.model_name}] No JSON text found in LLM response for {power_name}.")
|
||||||
f"[{self.model_name}] No JSON text found in LLM response for {power_name}."
|
|
||||||
)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 3a) Try JSON loading
|
# 3a) Try JSON loading
|
||||||
|
|
@ -293,14 +270,12 @@ class BaseModelClient:
|
||||||
data = json.loads(json_text)
|
data = json.loads(json_text)
|
||||||
return data.get("orders", None)
|
return data.get("orders", None)
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.warning(
|
logger.warning(f"[{self.model_name}] JSON decode failed for {power_name}: {e}. Trying to fix common issues.")
|
||||||
f"[{self.model_name}] JSON decode failed for {power_name}: {e}. Trying to fix common issues."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Try to fix common JSON issues
|
# Try to fix common JSON issues
|
||||||
try:
|
try:
|
||||||
# Remove trailing commas
|
# Remove trailing commas
|
||||||
fixed_json = re.sub(r',\s*([\}\]])', r'\1', json_text)
|
fixed_json = re.sub(r",\s*([\}\]])", r"\1", json_text)
|
||||||
# Fix single quotes to double quotes
|
# Fix single quotes to double quotes
|
||||||
fixed_json = fixed_json.replace("'", '"')
|
fixed_json = fixed_json.replace("'", '"')
|
||||||
# Try parsing again
|
# Try parsing again
|
||||||
|
|
@ -308,14 +283,12 @@ class BaseModelClient:
|
||||||
logger.info(f"[{self.model_name}] Successfully parsed JSON after fixes for {power_name}")
|
logger.info(f"[{self.model_name}] Successfully parsed JSON after fixes for {power_name}")
|
||||||
return data.get("orders", None)
|
return data.get("orders", None)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
logger.warning(
|
logger.warning(f"[{self.model_name}] JSON decode still failed after fixes for {power_name}. Trying to remove inline comments.")
|
||||||
f"[{self.model_name}] JSON decode still failed after fixes for {power_name}. Trying to remove inline comments."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Try to remove inline comments (// style)
|
# Try to remove inline comments (// style)
|
||||||
try:
|
try:
|
||||||
# Remove // comments from each line
|
# Remove // comments from each line
|
||||||
lines = json_text.split('\n')
|
lines = json_text.split("\n")
|
||||||
cleaned_lines = []
|
cleaned_lines = []
|
||||||
for line in lines:
|
for line in lines:
|
||||||
# Find // that's not inside quotes
|
# Find // that's not inside quotes
|
||||||
|
|
@ -326,12 +299,12 @@ class BaseModelClient:
|
||||||
if escape_next:
|
if escape_next:
|
||||||
escape_next = False
|
escape_next = False
|
||||||
continue
|
continue
|
||||||
if char == '\\':
|
if char == "\\":
|
||||||
escape_next = True
|
escape_next = True
|
||||||
continue
|
continue
|
||||||
if char == '"' and not escape_next:
|
if char == '"' and not escape_next:
|
||||||
in_quotes = not in_quotes
|
in_quotes = not in_quotes
|
||||||
if not in_quotes and line[i:i+2] == '//':
|
if not in_quotes and line[i : i + 2] == "//":
|
||||||
comment_pos = i
|
comment_pos = i
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
@ -342,17 +315,15 @@ class BaseModelClient:
|
||||||
cleaned_line = line
|
cleaned_line = line
|
||||||
cleaned_lines.append(cleaned_line)
|
cleaned_lines.append(cleaned_line)
|
||||||
|
|
||||||
comment_free_json = '\n'.join(cleaned_lines)
|
comment_free_json = "\n".join(cleaned_lines)
|
||||||
# Also remove trailing commas after comment removal
|
# Also remove trailing commas after comment removal
|
||||||
comment_free_json = re.sub(r',\s*([\}\]])', r'\1', comment_free_json)
|
comment_free_json = re.sub(r",\s*([\}\]])", r"\1", comment_free_json)
|
||||||
|
|
||||||
data = json.loads(comment_free_json)
|
data = json.loads(comment_free_json)
|
||||||
logger.info(f"[{self.model_name}] Successfully parsed JSON after removing inline comments for {power_name}")
|
logger.info(f"[{self.model_name}] Successfully parsed JSON after removing inline comments for {power_name}")
|
||||||
return data.get("orders", None)
|
return data.get("orders", None)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
logger.warning(
|
logger.warning(f"[{self.model_name}] JSON decode still failed after removing comments for {power_name}. Trying bracket fallback.")
|
||||||
f"[{self.model_name}] JSON decode still failed after removing comments for {power_name}. Trying bracket fallback."
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3b) Attempt bracket fallback: we look for the substring after "orders"
|
# 3b) Attempt bracket fallback: we look for the substring after "orders"
|
||||||
# E.g. "orders: ['A BUD H']" and parse it. This is risky but can help with minor JSON format errors.
|
# E.g. "orders: ['A BUD H']" and parse it. This is risky but can help with minor JSON format errors.
|
||||||
|
|
@ -366,16 +337,12 @@ class BaseModelClient:
|
||||||
if isinstance(moves, list):
|
if isinstance(moves, list):
|
||||||
return moves
|
return moves
|
||||||
except Exception as e2:
|
except Exception as e2:
|
||||||
logger.warning(
|
logger.warning(f"[{self.model_name}] Bracket fallback parse also failed for {power_name}: {e2}")
|
||||||
f"[{self.model_name}] Bracket fallback parse also failed for {power_name}: {e2}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# If all attempts failed
|
# If all attempts failed
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _validate_orders(
|
def _validate_orders(self, moves: List[str], possible_orders: Dict[str, List[str]]) -> Tuple[List[str], List[str]]: # MODIFIED RETURN TYPE
|
||||||
self, moves: List[str], possible_orders: Dict[str, List[str]]
|
|
||||||
) -> Tuple[List[str], List[str]]: # MODIFIED RETURN TYPE
|
|
||||||
"""
|
"""
|
||||||
Filter out invalid moves, fill missing with HOLD, else fallback.
|
Filter out invalid moves, fill missing with HOLD, else fallback.
|
||||||
Returns a tuple: (validated_moves, invalid_moves_found)
|
Returns a tuple: (validated_moves, invalid_moves_found)
|
||||||
|
|
@ -405,15 +372,15 @@ class BaseModelClient:
|
||||||
for loc, orders_list in possible_orders.items():
|
for loc, orders_list in possible_orders.items():
|
||||||
if loc not in used_locs and orders_list:
|
if loc not in used_locs and orders_list:
|
||||||
hold_candidates = [o for o in orders_list if o.endswith("H")]
|
hold_candidates = [o for o in orders_list if o.endswith("H")]
|
||||||
validated.append(
|
validated.append(hold_candidates[0] if hold_candidates else orders_list[0])
|
||||||
hold_candidates[0] if hold_candidates else orders_list[0]
|
|
||||||
)
|
|
||||||
|
|
||||||
if not validated and not invalid_moves_found: # Only if LLM provided no valid moves and no invalid moves (e.g. empty list from LLM)
|
if not validated and not invalid_moves_found: # Only if LLM provided no valid moves and no invalid moves (e.g. empty list from LLM)
|
||||||
logger.warning(f"[{self.model_name}] No valid LLM moves provided and no invalid ones to report. Using fallback.")
|
logger.warning(f"[{self.model_name}] No valid LLM moves provided and no invalid ones to report. Using fallback.")
|
||||||
return self.fallback_orders(possible_orders), []
|
return self.fallback_orders(possible_orders), []
|
||||||
elif not validated and invalid_moves_found: # All LLM moves were invalid
|
elif not validated and invalid_moves_found: # All LLM moves were invalid
|
||||||
logger.warning(f"[{self.model_name}] All LLM moves invalid ({len(invalid_moves_found)} found), using fallback. Invalid: {invalid_moves_found}")
|
logger.warning(
|
||||||
|
f"[{self.model_name}] All LLM moves invalid ({len(invalid_moves_found)} found), using fallback. Invalid: {invalid_moves_found}"
|
||||||
|
)
|
||||||
# We return empty list for validated, but the invalid_moves_found list is populated
|
# We return empty list for validated, but the invalid_moves_found list is populated
|
||||||
return self.fallback_orders(possible_orders), invalid_moves_found
|
return self.fallback_orders(possible_orders), invalid_moves_found
|
||||||
|
|
||||||
|
|
@ -444,7 +411,6 @@ class BaseModelClient:
|
||||||
agent_relationships: Optional[Dict[str, str]] = None,
|
agent_relationships: Optional[Dict[str, str]] = None,
|
||||||
agent_private_diary_str: Optional[str] = None, # Added
|
agent_private_diary_str: Optional[str] = None, # Added
|
||||||
) -> str:
|
) -> str:
|
||||||
|
|
||||||
instructions = load_prompt("planning_instructions.txt", prompts_dir=self.prompts_dir)
|
instructions = load_prompt("planning_instructions.txt", prompts_dir=self.prompts_dir)
|
||||||
|
|
||||||
context = self.build_context_prompt(
|
context = self.build_context_prompt(
|
||||||
|
|
@ -508,7 +474,15 @@ class BaseModelClient:
|
||||||
unanswered_messages += "\nNo urgent messages requiring direct responses.\n"
|
unanswered_messages += "\nNo urgent messages requiring direct responses.\n"
|
||||||
|
|
||||||
final_prompt = context + unanswered_messages + "\n\n" + instructions
|
final_prompt = context + unanswered_messages + "\n\n" + instructions
|
||||||
final_prompt = final_prompt.replace('AUSTRIA', 'Austria').replace('ENGLAND', "England").replace('FRANCE', 'France').replace('GERMANY', 'Germany').replace('ITALY', "Italy").replace('RUSSIA', 'Russia').replace('TURKEY', 'Turkey')
|
final_prompt = (
|
||||||
|
final_prompt.replace("AUSTRIA", "Austria")
|
||||||
|
.replace("ENGLAND", "England")
|
||||||
|
.replace("FRANCE", "France")
|
||||||
|
.replace("GERMANY", "Germany")
|
||||||
|
.replace("ITALY", "Italy")
|
||||||
|
.replace("RUSSIA", "Russia")
|
||||||
|
.replace("TURKEY", "Turkey")
|
||||||
|
)
|
||||||
return final_prompt
|
return final_prompt
|
||||||
|
|
||||||
async def get_planning_reply( # Renamed from get_plan to avoid conflict with get_plan in agent.py
|
async def get_planning_reply( # Renamed from get_plan to avoid conflict with get_plan in agent.py
|
||||||
|
|
@ -524,7 +498,6 @@ class BaseModelClient:
|
||||||
agent_relationships: Optional[Dict[str, str]] = None,
|
agent_relationships: Optional[Dict[str, str]] = None,
|
||||||
agent_private_diary_str: Optional[str] = None, # Added
|
agent_private_diary_str: Optional[str] = None, # Added
|
||||||
) -> str:
|
) -> str:
|
||||||
|
|
||||||
prompt = self.build_planning_prompt(
|
prompt = self.build_planning_prompt(
|
||||||
game,
|
game,
|
||||||
board_state,
|
board_state,
|
||||||
|
|
@ -545,7 +518,7 @@ class BaseModelClient:
|
||||||
log_file_path=log_file_path,
|
log_file_path=log_file_path,
|
||||||
power_name=power_name,
|
power_name=power_name,
|
||||||
phase=game_phase, # Use game_phase for logging
|
phase=game_phase, # Use game_phase for logging
|
||||||
response_type='plan_reply', # Changed from 'plan' to avoid confusion
|
response_type="plan_reply", # Changed from 'plan' to avoid confusion
|
||||||
)
|
)
|
||||||
logger.debug(f"[{self.model_name}] Raw LLM response for {power_name} planning reply:\n{raw_response}")
|
logger.debug(f"[{self.model_name}] Raw LLM response for {power_name} planning reply:\n{raw_response}")
|
||||||
return raw_response
|
return raw_response
|
||||||
|
|
@ -592,7 +565,7 @@ class BaseModelClient:
|
||||||
log_file_path=log_file_path,
|
log_file_path=log_file_path,
|
||||||
power_name=power_name,
|
power_name=power_name,
|
||||||
phase=game_phase,
|
phase=game_phase,
|
||||||
response_type='negotiation', # For run_llm_and_log's internal context
|
response_type="negotiation", # For run_llm_and_log's internal context
|
||||||
)
|
)
|
||||||
logger.debug(f"[{self.model_name}] Raw LLM response for {power_name}:\n{raw_response}")
|
logger.debug(f"[{self.model_name}] Raw LLM response for {power_name}:\n{raw_response}")
|
||||||
|
|
||||||
|
|
@ -600,13 +573,10 @@ class BaseModelClient:
|
||||||
if config.USE_UNFORMATTED_PROMPTS:
|
if config.USE_UNFORMATTED_PROMPTS:
|
||||||
# Local import to avoid circular dependency
|
# Local import to avoid circular dependency
|
||||||
from .formatter import format_with_gemini_flash, FORMAT_CONVERSATION
|
from .formatter import format_with_gemini_flash, FORMAT_CONVERSATION
|
||||||
|
|
||||||
# Format the natural language response into structured JSON
|
# Format the natural language response into structured JSON
|
||||||
formatted_response = await format_with_gemini_flash(
|
formatted_response = await format_with_gemini_flash(
|
||||||
raw_response,
|
raw_response, FORMAT_CONVERSATION, power_name=power_name, phase=game_phase, log_file_path=log_file_path
|
||||||
FORMAT_CONVERSATION,
|
|
||||||
power_name=power_name,
|
|
||||||
phase=game_phase,
|
|
||||||
log_file_path=log_file_path
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Use the raw response directly (already formatted)
|
# Use the raw response directly (already formatted)
|
||||||
|
|
@ -632,10 +602,10 @@ class BaseModelClient:
|
||||||
# Original parsing logic as fallback
|
# Original parsing logic as fallback
|
||||||
if not parsed_messages:
|
if not parsed_messages:
|
||||||
# Attempt to find blocks enclosed in {{...}}
|
# Attempt to find blocks enclosed in {{...}}
|
||||||
double_brace_blocks = re.findall(r'\{\{(.*?)\}\}', raw_response, re.DOTALL)
|
double_brace_blocks = re.findall(r"\{\{(.*?)\}\}", raw_response, re.DOTALL)
|
||||||
if double_brace_blocks:
|
if double_brace_blocks:
|
||||||
# If {{...}} blocks are found, assume each is a self-contained JSON object
|
# If {{...}} blocks are found, assume each is a self-contained JSON object
|
||||||
json_blocks.extend(['{' + block.strip() + '}' for block in double_brace_blocks])
|
json_blocks.extend(["{" + block.strip() + "}" for block in double_brace_blocks])
|
||||||
else:
|
else:
|
||||||
# If no {{...}} blocks, look for ```json ... ``` markdown blocks
|
# If no {{...}} blocks, look for ```json ... ``` markdown blocks
|
||||||
code_block_match = re.search(r"```json\n(.*?)\n```", raw_response, re.DOTALL)
|
code_block_match = re.search(r"```json\n(.*?)\n```", raw_response, re.DOTALL)
|
||||||
|
|
@ -650,10 +620,10 @@ class BaseModelClient:
|
||||||
json_blocks = [json.dumps(data)]
|
json_blocks = [json.dumps(data)]
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
# If parsing the whole block fails, fall back to regex for individual objects
|
# If parsing the whole block fails, fall back to regex for individual objects
|
||||||
json_blocks = re.findall(r'\{.*?\}', potential_json_array_or_objects, re.DOTALL)
|
json_blocks = re.findall(r"\{.*?\}", potential_json_array_or_objects, re.DOTALL)
|
||||||
else:
|
else:
|
||||||
# If no markdown block, fall back to regex for any JSON object in the response
|
# If no markdown block, fall back to regex for any JSON object in the response
|
||||||
json_blocks = re.findall(r'\{.*?\}', raw_response, re.DOTALL)
|
json_blocks = re.findall(r"\{.*?\}", raw_response, re.DOTALL)
|
||||||
|
|
||||||
# Process json_blocks if we have them from fallback parsing
|
# Process json_blocks if we have them from fallback parsing
|
||||||
if not parsed_messages and json_blocks:
|
if not parsed_messages and json_blocks:
|
||||||
|
|
@ -661,7 +631,7 @@ class BaseModelClient:
|
||||||
try:
|
try:
|
||||||
cleaned_block = block.strip()
|
cleaned_block = block.strip()
|
||||||
# Attempt to fix common JSON issues like trailing commas before parsing
|
# Attempt to fix common JSON issues like trailing commas before parsing
|
||||||
cleaned_block = re.sub(r',\s*([\}\]])', r'\1', cleaned_block)
|
cleaned_block = re.sub(r",\s*([\}\]])", r"\1", cleaned_block)
|
||||||
parsed_message = json.loads(cleaned_block)
|
parsed_message = json.loads(cleaned_block)
|
||||||
parsed_messages.append(parsed_message)
|
parsed_messages.append(parsed_message)
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
|
|
@ -710,7 +680,7 @@ class BaseModelClient:
|
||||||
response_type="negotiation_message",
|
response_type="negotiation_message",
|
||||||
raw_input_prompt=raw_input_prompt,
|
raw_input_prompt=raw_input_prompt,
|
||||||
raw_response=raw_response,
|
raw_response=raw_response,
|
||||||
success=success_status
|
success=success_status,
|
||||||
)
|
)
|
||||||
return messages_to_return
|
return messages_to_return
|
||||||
|
|
||||||
|
|
@ -770,7 +740,7 @@ class BaseModelClient:
|
||||||
log_file_path=log_file_path,
|
log_file_path=log_file_path,
|
||||||
power_name=power_name,
|
power_name=power_name,
|
||||||
phase=game.current_short_phase,
|
phase=game.current_short_phase,
|
||||||
response_type='plan_generation', # More specific type for run_llm_and_log context
|
response_type="plan_generation", # More specific type for run_llm_and_log context
|
||||||
)
|
)
|
||||||
logger.debug(f"[{self.model_name}] Raw LLM response for {power_name} plan generation:\n{raw_plan_response}")
|
logger.debug(f"[{self.model_name}] Raw LLM response for {power_name} plan generation:\n{raw_plan_response}")
|
||||||
# No parsing needed for the plan, return the raw string
|
# No parsing needed for the plan, return the raw string
|
||||||
|
|
@ -790,7 +760,7 @@ class BaseModelClient:
|
||||||
response_type="plan_generation", # Specific type for CSV logging
|
response_type="plan_generation", # Specific type for CSV logging
|
||||||
raw_input_prompt=full_prompt, # Renamed from 'full_prompt' to match log_llm_response arg
|
raw_input_prompt=full_prompt, # Renamed from 'full_prompt' to match log_llm_response arg
|
||||||
raw_response=raw_plan_response,
|
raw_response=raw_plan_response,
|
||||||
success=success_status
|
success=success_status,
|
||||||
# token_usage and cost can be added later
|
# token_usage and cost can be added later
|
||||||
)
|
)
|
||||||
return plan_to_return
|
return plan_to_return
|
||||||
|
|
@ -813,11 +783,7 @@ class OpenAIClient(BaseModelClient):
|
||||||
):
|
):
|
||||||
super().__init__(model_name, prompts_dir=prompts_dir)
|
super().__init__(model_name, prompts_dir=prompts_dir)
|
||||||
|
|
||||||
self.base_url = (
|
self.base_url = base_url or os.environ.get("OPENAI_BASE_URL") or "https://api.openai.com/v1"
|
||||||
base_url
|
|
||||||
or os.environ.get("OPENAI_BASE_URL")
|
|
||||||
or "https://api.openai.com/v1"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
|
self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
|
|
@ -832,11 +798,7 @@ class OpenAIClient(BaseModelClient):
|
||||||
inject_random_seed: bool = True,
|
inject_random_seed: bool = True,
|
||||||
) -> str:
|
) -> str:
|
||||||
try:
|
try:
|
||||||
system_prompt_content = (
|
system_prompt_content = f"{generate_random_seed()}\n\n{self.system_prompt}" if inject_random_seed else self.system_prompt
|
||||||
f"{generate_random_seed()}\n\n{self.system_prompt}"
|
|
||||||
if inject_random_seed
|
|
||||||
else self.system_prompt
|
|
||||||
)
|
|
||||||
prompt_with_cta = f"{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:"
|
prompt_with_cta = f"{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:"
|
||||||
|
|
||||||
response = await self.client.chat.completions.create(
|
response = await self.client.chat.completions.create(
|
||||||
|
|
@ -888,20 +850,14 @@ class ClaudeClient(BaseModelClient):
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
)
|
)
|
||||||
if not response.content:
|
if not response.content:
|
||||||
logger.warning(
|
logger.warning(f"[{self.model_name}] Empty content in Claude generate_response. Returning empty.")
|
||||||
f"[{self.model_name}] Empty content in Claude generate_response. Returning empty."
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
return response.content[0].text.strip() if response.content else ""
|
return response.content[0].text.strip() if response.content else ""
|
||||||
except json.JSONDecodeError as json_err:
|
except json.JSONDecodeError as json_err:
|
||||||
logger.error(
|
logger.error(f"[{self.model_name}] JSON decoding failed in generate_response: {json_err}")
|
||||||
f"[{self.model_name}] JSON decoding failed in generate_response: {json_err}"
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(f"[{self.model_name}] Unexpected error in generate_response: {e}")
|
||||||
f"[{self.model_name}] Unexpected error in generate_response: {e}"
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -929,19 +885,14 @@ class GeminiClient(BaseModelClient):
|
||||||
full_prompt = system_prompt_content + prompt + "\n\nPROVIDE YOUR RESPONSE BELOW:"
|
full_prompt = system_prompt_content + prompt + "\n\nPROVIDE YOUR RESPONSE BELOW:"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
generation_config = genai.types.GenerationConfig(
|
generation_config = genai.types.GenerationConfig(temperature=temperature, max_output_tokens=self.max_tokens)
|
||||||
temperature=temperature,
|
|
||||||
max_output_tokens=self.max_tokens
|
|
||||||
)
|
|
||||||
response = await self.client.generate_content_async(
|
response = await self.client.generate_content_async(
|
||||||
contents=full_prompt,
|
contents=full_prompt,
|
||||||
generation_config=generation_config,
|
generation_config=generation_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not response or not response.text:
|
if not response or not response.text:
|
||||||
logger.warning(
|
logger.warning(f"[{self.model_name}] Empty Gemini generate_response. Returning empty.")
|
||||||
f"[{self.model_name}] Empty Gemini generate_response. Returning empty."
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
return response.text.strip()
|
return response.text.strip()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -957,10 +908,7 @@ class DeepSeekClient(BaseModelClient):
|
||||||
def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
|
def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
|
||||||
super().__init__(model_name, prompts_dir=prompts_dir)
|
super().__init__(model_name, prompts_dir=prompts_dir)
|
||||||
self.api_key = os.environ.get("DEEPSEEK_API_KEY")
|
self.api_key = os.environ.get("DEEPSEEK_API_KEY")
|
||||||
self.client = AsyncDeepSeekOpenAI(
|
self.client = AsyncDeepSeekOpenAI(api_key=self.api_key, base_url="https://api.deepseek.com/")
|
||||||
api_key=self.api_key,
|
|
||||||
base_url="https://api.deepseek.com/"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def generate_response(self, prompt: str, temperature: float = 0.0, inject_random_seed: bool = True) -> str:
|
async def generate_response(self, prompt: str, temperature: float = 0.0, inject_random_seed: bool = True) -> str:
|
||||||
try:
|
try:
|
||||||
|
|
@ -986,9 +934,7 @@ class DeepSeekClient(BaseModelClient):
|
||||||
logger.debug(f"[{self.model_name}] Raw DeepSeek response:\n{response}")
|
logger.debug(f"[{self.model_name}] Raw DeepSeek response:\n{response}")
|
||||||
|
|
||||||
if not response or not response.choices:
|
if not response or not response.choices:
|
||||||
logger.warning(
|
logger.warning(f"[{self.model_name}] No valid response in generate_response.")
|
||||||
f"[{self.model_name}] No valid response in generate_response."
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
content = response.choices[0].message.content.strip()
|
content = response.choices[0].message.content.strip()
|
||||||
|
|
@ -998,9 +944,7 @@ class DeepSeekClient(BaseModelClient):
|
||||||
return content
|
return content
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(f"[{self.model_name}] Unexpected error in generate_response: {e}")
|
||||||
f"[{self.model_name}] Unexpected error in generate_response: {e}"
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1037,19 +981,14 @@ class OpenAIResponsesClient(BaseModelClient):
|
||||||
"max_tokens": self.max_tokens,
|
"max_tokens": self.max_tokens,
|
||||||
}
|
}
|
||||||
|
|
||||||
headers = {
|
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {self.api_key}"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Make the API call using aiohttp
|
# Make the API call using aiohttp
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.post(self.base_url, json=payload, headers=headers) as response:
|
async with session.post(self.base_url, json=payload, headers=headers) as response:
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
error_text = await response.text()
|
error_text = await response.text()
|
||||||
logger.error(
|
logger.error(f"[{self.model_name}] API error (status {response.status}): {error_text}")
|
||||||
f"[{self.model_name}] API error (status {response.status}): {error_text}"
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
response_data = await response.json()
|
response_data = await response.json()
|
||||||
|
|
@ -1059,24 +998,18 @@ class OpenAIResponsesClient(BaseModelClient):
|
||||||
try:
|
try:
|
||||||
outputs = response_data.get("output", [])
|
outputs = response_data.get("output", [])
|
||||||
if len(outputs) < 2:
|
if len(outputs) < 2:
|
||||||
logger.warning(
|
logger.warning(f"[{self.model_name}] Unexpected output structure. Full response: {response_data}")
|
||||||
f"[{self.model_name}] Unexpected output structure. Full response: {response_data}"
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# The message is typically in the second output item
|
# The message is typically in the second output item
|
||||||
message_output = outputs[1]
|
message_output = outputs[1]
|
||||||
if message_output.get("type") != "message":
|
if message_output.get("type") != "message":
|
||||||
logger.warning(
|
logger.warning(f"[{self.model_name}] Expected message type in output[1]. Got: {message_output.get('type')}")
|
||||||
f"[{self.model_name}] Expected message type in output[1]. Got: {message_output.get('type')}"
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
content_list = message_output.get("content", [])
|
content_list = message_output.get("content", [])
|
||||||
if not content_list:
|
if not content_list:
|
||||||
logger.warning(
|
logger.warning(f"[{self.model_name}] Empty content list in message output")
|
||||||
f"[{self.model_name}] Empty content list in message output"
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# Look for the content item with type 'output_text'
|
# Look for the content item with type 'output_text'
|
||||||
|
|
@ -1087,28 +1020,20 @@ class OpenAIResponsesClient(BaseModelClient):
|
||||||
break
|
break
|
||||||
|
|
||||||
if not text_content:
|
if not text_content:
|
||||||
logger.warning(
|
logger.warning(f"[{self.model_name}] No output_text found in content. Full content: {content_list}")
|
||||||
f"[{self.model_name}] No output_text found in content. Full content: {content_list}"
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
return text_content.strip()
|
return text_content.strip()
|
||||||
|
|
||||||
except (KeyError, IndexError, TypeError) as e:
|
except (KeyError, IndexError, TypeError) as e:
|
||||||
logger.error(
|
logger.error(f"[{self.model_name}] Error parsing response structure: {e}. Full response: {response_data}")
|
||||||
f"[{self.model_name}] Error parsing response structure: {e}. Full response: {response_data}"
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
except aiohttp.ClientError as e:
|
except aiohttp.ClientError as e:
|
||||||
logger.error(
|
logger.error(f"[{self.model_name}] HTTP client error in generate_response: {e}")
|
||||||
f"[{self.model_name}] HTTP client error in generate_response: {e}"
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(f"[{self.model_name}] Unexpected error in generate_response: {e}")
|
||||||
f"[{self.model_name}] Unexpected error in generate_response: {e}"
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1129,10 +1054,7 @@ class OpenRouterClient(BaseModelClient):
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError("OPENROUTER_API_KEY environment variable is required")
|
raise ValueError("OPENROUTER_API_KEY environment variable is required")
|
||||||
|
|
||||||
self.client = AsyncOpenAI(
|
self.client = AsyncOpenAI(base_url="https://openrouter.ai/api/v1", api_key=self.api_key)
|
||||||
base_url="https://openrouter.ai/api/v1",
|
|
||||||
api_key=self.api_key
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(f"[{self.model_name}] Initialized OpenRouter client")
|
logger.debug(f"[{self.model_name}] Initialized OpenRouter client")
|
||||||
|
|
||||||
|
|
@ -1150,10 +1072,7 @@ class OpenRouterClient(BaseModelClient):
|
||||||
# Prepare standard OpenAI-compatible request
|
# Prepare standard OpenAI-compatible request
|
||||||
response = await self.client.chat.completions.create(
|
response = await self.client.chat.completions.create(
|
||||||
model=self.model_name,
|
model=self.model_name,
|
||||||
messages=[
|
messages=[{"role": "system", "content": system_prompt_content}, {"role": "user", "content": prompt_with_cta}],
|
||||||
{"role": "system", "content": system_prompt_content},
|
|
||||||
{"role": "user", "content": prompt_with_cta}
|
|
||||||
],
|
|
||||||
max_tokens=self.max_tokens,
|
max_tokens=self.max_tokens,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
)
|
)
|
||||||
|
|
@ -1245,7 +1164,7 @@ class TogetherAIClient(BaseModelClient):
|
||||||
##############################################################################
|
##############################################################################
|
||||||
# RequestsOpenAIClient – sync requests, wrapped async (original + api_key)
|
# RequestsOpenAIClient – sync requests, wrapped async (original + api_key)
|
||||||
##############################################################################
|
##############################################################################
|
||||||
import requests, asyncio
|
|
||||||
|
|
||||||
class RequestsOpenAIClient(BaseModelClient):
|
class RequestsOpenAIClient(BaseModelClient):
|
||||||
"""
|
"""
|
||||||
|
|
@ -1266,11 +1185,7 @@ class RequestsOpenAIClient(BaseModelClient):
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError("OPENAI_API_KEY missing and no inline key provided")
|
raise ValueError("OPENAI_API_KEY missing and no inline key provided")
|
||||||
|
|
||||||
self.base_url = (
|
self.base_url = (base_url or os.environ.get("OPENAI_BASE_URL") or "https://api.openai.com/v1").rstrip("/")
|
||||||
base_url
|
|
||||||
or os.environ.get("OPENAI_BASE_URL")
|
|
||||||
or "https://api.openai.com/v1"
|
|
||||||
).rstrip("/")
|
|
||||||
|
|
||||||
self.endpoint = f"{self.base_url}/chat/completions"
|
self.endpoint = f"{self.base_url}/chat/completions"
|
||||||
|
|
||||||
|
|
@ -1291,11 +1206,7 @@ class RequestsOpenAIClient(BaseModelClient):
|
||||||
temperature: float = 0.0,
|
temperature: float = 0.0,
|
||||||
inject_random_seed: bool = True,
|
inject_random_seed: bool = True,
|
||||||
) -> str:
|
) -> str:
|
||||||
system_prompt_content = (
|
system_prompt_content = f"{generate_random_seed()}\n\n{self.system_prompt}" if inject_random_seed else self.system_prompt
|
||||||
f"{generate_random_seed()}\n\n{self.system_prompt}"
|
|
||||||
if inject_random_seed
|
|
||||||
else self.system_prompt
|
|
||||||
)
|
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"model": self.model_name,
|
"model": self.model_name,
|
||||||
|
|
@ -1322,7 +1233,6 @@ class RequestsOpenAIClient(BaseModelClient):
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
##############################################################################
|
##############################################################################
|
||||||
# 3) Factory to Load Model Client
|
# 3) Factory to Load Model Client
|
||||||
##############################################################################
|
##############################################################################
|
||||||
|
|
@ -1332,6 +1242,7 @@ class ModelSpec(NamedTuple):
|
||||||
base: Optional[str] # 'https://proxy.foo'
|
base: Optional[str] # 'https://proxy.foo'
|
||||||
key: Optional[str] # 'sk-…' (may be None)
|
key: Optional[str] # 'sk-…' (may be None)
|
||||||
|
|
||||||
|
|
||||||
def _parse_model_spec(raw: str) -> ModelSpec:
|
def _parse_model_spec(raw: str) -> ModelSpec:
|
||||||
"""
|
"""
|
||||||
Splits once on '#' (API key) and once on '@' (base URL). A leading
|
Splits once on '#' (API key) and once on '@' (base URL). A leading
|
||||||
|
|
@ -1441,8 +1352,6 @@ def load_model_client(model_id: str, prompts_dir: Optional[str] = None) -> BaseM
|
||||||
return OpenAIClient(model_name, prompts_dir, base_url)
|
return OpenAIClient(model_name, prompts_dir, base_url)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
##############################################################################
|
##############################################################################
|
||||||
# 1) Add a method to filter visible messages (near top-level or in BaseModelClient)
|
# 1) Add a method to filter visible messages (near top-level or in BaseModelClient)
|
||||||
##############################################################################
|
##############################################################################
|
||||||
|
|
@ -1453,11 +1362,6 @@ def get_visible_messages_for_power(conversation_messages, power_name):
|
||||||
visible = []
|
visible = []
|
||||||
for msg in conversation_messages:
|
for msg in conversation_messages:
|
||||||
# GLOBAL might be 'ALL' or 'GLOBAL' depending on your usage
|
# GLOBAL might be 'ALL' or 'GLOBAL' depending on your usage
|
||||||
if (
|
if msg["recipient"] == "ALL" or msg["recipient"] == "GLOBAL" or msg["sender"] == power_name or msg["recipient"] == power_name:
|
||||||
msg["recipient"] == "ALL"
|
|
||||||
or msg["recipient"] == "GLOBAL"
|
|
||||||
or msg["sender"] == power_name
|
|
||||||
or msg["recipient"] == power_name
|
|
||||||
):
|
|
||||||
visible.append(msg)
|
visible.append(msg)
|
||||||
return visible # already in chronological order if appended that way
|
return visible # already in chronological order if appended that way
|
||||||
|
|
|
||||||
|
|
@ -11,25 +11,28 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> str | None:
|
def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> str | None:
|
||||||
"""A local copy of the helper from agent.py to avoid circular imports."""
|
"""A local copy of the helper from agent.py to avoid circular imports."""
|
||||||
import os
|
import os
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if prompts_dir:
|
if prompts_dir:
|
||||||
filepath = os.path.join(prompts_dir, filename)
|
filepath = os.path.join(prompts_dir, filename)
|
||||||
else:
|
else:
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
default_prompts_dir = os.path.join(current_dir, 'prompts')
|
default_prompts_dir = os.path.join(current_dir, "prompts")
|
||||||
filepath = os.path.join(default_prompts_dir, filename)
|
filepath = os.path.join(default_prompts_dir, filename)
|
||||||
|
|
||||||
with open(filepath, 'r', encoding='utf-8') as f:
|
with open(filepath, "r", encoding="utf-8") as f:
|
||||||
return f.read()
|
return f.read()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error loading prompt file {filepath}: {e}")
|
logger.error(f"Error loading prompt file {filepath}: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def run_diary_consolidation(
|
async def run_diary_consolidation(
|
||||||
agent: 'DiplomacyAgent',
|
agent: "DiplomacyAgent",
|
||||||
game: "Game",
|
game: "Game",
|
||||||
log_file_path: str,
|
log_file_path: str,
|
||||||
entries_to_keep_unsummarized: int = 6,
|
entries_to_keep_unsummarized: int = 6,
|
||||||
|
|
@ -39,70 +42,42 @@ async def run_diary_consolidation(
|
||||||
Consolidate older diary entries while keeping recent ones.
|
Consolidate older diary entries while keeping recent ones.
|
||||||
This is the logic moved from the DiplomacyAgent class.
|
This is the logic moved from the DiplomacyAgent class.
|
||||||
"""
|
"""
|
||||||
logger.info(
|
logger.info(f"[{agent.power_name}] CONSOLIDATION START — {len(agent.full_private_diary)} total full entries")
|
||||||
f"[{agent.power_name}] CONSOLIDATION START — "
|
|
||||||
f"{len(agent.full_private_diary)} total full entries"
|
|
||||||
)
|
|
||||||
|
|
||||||
full_entries = [
|
full_entries = [e for e in agent.full_private_diary if not e.startswith("[CONSOLIDATED HISTORY]")]
|
||||||
e for e in agent.full_private_diary
|
|
||||||
if not e.startswith("[CONSOLIDATED HISTORY]")
|
|
||||||
]
|
|
||||||
|
|
||||||
if len(full_entries) <= entries_to_keep_unsummarized:
|
if len(full_entries) <= entries_to_keep_unsummarized:
|
||||||
agent.private_diary = list(agent.full_private_diary)
|
agent.private_diary = list(agent.full_private_diary)
|
||||||
logger.info(
|
logger.info(f"[{agent.power_name}] ≤ {entries_to_keep_unsummarized} full entries — skipping consolidation")
|
||||||
f"[{agent.power_name}] ≤ {entries_to_keep_unsummarized} full entries — "
|
|
||||||
"skipping consolidation"
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
boundary_entry = full_entries[-entries_to_keep_unsummarized]
|
boundary_entry = full_entries[-entries_to_keep_unsummarized]
|
||||||
match = re.search(r"\[[SFWRAB]\s*(\d{4})", boundary_entry)
|
match = re.search(r"\[[SFWRAB]\s*(\d{4})", boundary_entry)
|
||||||
if not match:
|
if not match:
|
||||||
logger.error(
|
logger.error(f"[{agent.power_name}] Could not parse year from boundary entry; aborting consolidation")
|
||||||
f"[{agent.power_name}] Could not parse year from boundary entry; "
|
|
||||||
"aborting consolidation"
|
|
||||||
)
|
|
||||||
agent.private_diary = list(agent.full_private_diary)
|
agent.private_diary = list(agent.full_private_diary)
|
||||||
return
|
return
|
||||||
|
|
||||||
cutoff_year = int(match.group(1))
|
cutoff_year = int(match.group(1))
|
||||||
logger.info(
|
logger.info(f"[{agent.power_name}] Cut-off year for consolidation: {cutoff_year}")
|
||||||
f"[{agent.power_name}] Cut-off year for consolidation: {cutoff_year}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _entry_year(entry: str) -> int | None:
|
def _entry_year(entry: str) -> int | None:
|
||||||
m = re.search(r"\[[SFWRAB]\s*(\d{4})", entry)
|
m = re.search(r"\[[SFWRAB]\s*(\d{4})", entry)
|
||||||
return int(m.group(1)) if m else None
|
return int(m.group(1)) if m else None
|
||||||
|
|
||||||
entries_to_summarize = [
|
entries_to_summarize = [e for e in full_entries if (_entry_year(e) is not None and _entry_year(e) < cutoff_year)]
|
||||||
e for e in full_entries
|
entries_to_keep = [e for e in full_entries if (_entry_year(e) is None or _entry_year(e) >= cutoff_year)]
|
||||||
if (_entry_year(e) is not None and _entry_year(e) < cutoff_year)
|
|
||||||
]
|
|
||||||
entries_to_keep = [
|
|
||||||
e for e in full_entries
|
|
||||||
if (_entry_year(e) is None or _entry_year(e) >= cutoff_year)
|
|
||||||
]
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(f"[{agent.power_name}] Summarising {len(entries_to_summarize)} entries; keeping {len(entries_to_keep)} recent entries verbatim")
|
||||||
f"[{agent.power_name}] Summarising {len(entries_to_summarize)} entries; "
|
|
||||||
f"keeping {len(entries_to_keep)} recent entries verbatim"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not entries_to_summarize:
|
if not entries_to_summarize:
|
||||||
agent.private_diary = list(agent.full_private_diary)
|
agent.private_diary = list(agent.full_private_diary)
|
||||||
logger.warning(
|
logger.warning(f"[{agent.power_name}] No eligible entries to summarise; context diary left unchanged")
|
||||||
f"[{agent.power_name}] No eligible entries to summarise; "
|
|
||||||
"context diary left unchanged"
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
prompt_template = _load_prompt_file("diary_consolidation_prompt.txt", prompts_dir=prompts_dir)
|
prompt_template = _load_prompt_file("diary_consolidation_prompt.txt", prompts_dir=prompts_dir)
|
||||||
if not prompt_template:
|
if not prompt_template:
|
||||||
logger.error(
|
logger.error(f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting")
|
||||||
f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting"
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
prompt = prompt_template.format(
|
prompt = prompt_template.format(
|
||||||
|
|
@ -132,23 +107,14 @@ async def run_diary_consolidation(
|
||||||
new_summary_entry = f"[CONSOLIDATED HISTORY] {consolidated_text}"
|
new_summary_entry = f"[CONSOLIDATED HISTORY] {consolidated_text}"
|
||||||
agent.private_diary = [new_summary_entry] + entries_to_keep
|
agent.private_diary = [new_summary_entry] + entries_to_keep
|
||||||
success_flag = "TRUE"
|
success_flag = "TRUE"
|
||||||
logger.info(
|
logger.info(f"[{agent.power_name}] Consolidation complete — {len(agent.private_diary)} context entries now")
|
||||||
f"[{agent.power_name}] Consolidation complete — "
|
|
||||||
f"{len(agent.private_diary)} context entries now"
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error(
|
logger.error(f"[{agent.power_name}] Diary consolidation failed: {exc}", exc_info=True)
|
||||||
f"[{agent.power_name}] Diary consolidation failed: {exc}", exc_info=True
|
|
||||||
)
|
|
||||||
finally:
|
finally:
|
||||||
log_llm_response(
|
log_llm_response(
|
||||||
log_file_path=log_file_path,
|
log_file_path=log_file_path,
|
||||||
model_name=(
|
model_name=(consolidation_client.model_name if consolidation_client is not None else agent.client.model_name),
|
||||||
consolidation_client.model_name
|
|
||||||
if consolidation_client is not None
|
|
||||||
else agent.client.model_name
|
|
||||||
),
|
|
||||||
power_name=agent.power_name,
|
power_name=agent.power_name,
|
||||||
phase=game.current_short_phase,
|
phase=game.current_short_phase,
|
||||||
response_type="diary_consolidation",
|
response_type="diary_consolidation",
|
||||||
|
|
|
||||||
|
|
@ -3,14 +3,13 @@ Formatter module for converting natural language LLM responses to structured JSO
|
||||||
Uses Gemini 2.5 Flash via OpenRouter to extract and format information from reasoning-focused responses.
|
Uses Gemini 2.5 Flash via OpenRouter to extract and format information from reasoning-focused responses.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
from typing import Optional
|
||||||
from typing import Dict, Any, Optional
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
# Import logging function and model configuration
|
# Import logging function and model configuration
|
||||||
from .utils import log_llm_response, get_special_models
|
from .utils import log_llm_response, get_special_models
|
||||||
|
|
||||||
# Import client loading function
|
# Import client loading function
|
||||||
from .clients import load_model_client
|
from .clients import load_model_client
|
||||||
|
|
||||||
|
|
@ -26,11 +25,7 @@ FORMAT_ORDER_DIARY = "order_diary"
|
||||||
|
|
||||||
|
|
||||||
async def format_with_gemini_flash(
|
async def format_with_gemini_flash(
|
||||||
raw_response: str,
|
raw_response: str, format_type: str, power_name: Optional[str] = None, phase: Optional[str] = None, log_file_path: Optional[str] = None
|
||||||
format_type: str,
|
|
||||||
power_name: Optional[str] = None,
|
|
||||||
phase: Optional[str] = None,
|
|
||||||
log_file_path: Optional[str] = None
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Format a natural language response into required JSON structure using Gemini Flash.
|
Format a natural language response into required JSON structure using Gemini Flash.
|
||||||
|
|
@ -52,7 +47,7 @@ async def format_with_gemini_flash(
|
||||||
FORMAT_NEGOTIATION_DIARY: "formatting/format_negotiation_diary.txt",
|
FORMAT_NEGOTIATION_DIARY: "formatting/format_negotiation_diary.txt",
|
||||||
FORMAT_ORDERS: "formatting/format_orders.txt",
|
FORMAT_ORDERS: "formatting/format_orders.txt",
|
||||||
FORMAT_INITIAL_STATE: "formatting/format_initial_state.txt",
|
FORMAT_INITIAL_STATE: "formatting/format_initial_state.txt",
|
||||||
FORMAT_ORDER_DIARY: "formatting/format_order_diary.txt"
|
FORMAT_ORDER_DIARY: "formatting/format_order_diary.txt",
|
||||||
}
|
}
|
||||||
|
|
||||||
if format_type not in format_prompts:
|
if format_type not in format_prompts:
|
||||||
|
|
@ -63,7 +58,7 @@ async def format_with_gemini_flash(
|
||||||
if not prompt_file.exists():
|
if not prompt_file.exists():
|
||||||
raise FileNotFoundError(f"Formatting prompt not found: {prompt_file}")
|
raise FileNotFoundError(f"Formatting prompt not found: {prompt_file}")
|
||||||
|
|
||||||
with open(prompt_file, 'r') as f:
|
with open(prompt_file, "r") as f:
|
||||||
format_prompt = f.read()
|
format_prompt = f.read()
|
||||||
|
|
||||||
# Replace placeholder with actual response
|
# Replace placeholder with actual response
|
||||||
|
|
@ -87,7 +82,7 @@ async def format_with_gemini_flash(
|
||||||
formatted_response = await formatter_client.generate_response(
|
formatted_response = await formatter_client.generate_response(
|
||||||
prompt=format_prompt,
|
prompt=format_prompt,
|
||||||
temperature=0, # Deterministic formatting
|
temperature=0, # Deterministic formatting
|
||||||
inject_random_seed=False # No need for random seed in formatting
|
inject_random_seed=False, # No need for random seed in formatting
|
||||||
)
|
)
|
||||||
|
|
||||||
if not formatted_response:
|
if not formatted_response:
|
||||||
|
|
@ -116,7 +111,7 @@ async def format_with_gemini_flash(
|
||||||
response_type=f"format_{format_type}",
|
response_type=f"format_{format_type}",
|
||||||
raw_input_prompt=format_prompt,
|
raw_input_prompt=format_prompt,
|
||||||
raw_response=formatted_response,
|
raw_response=formatted_response,
|
||||||
success="Success"
|
success="Success",
|
||||||
)
|
)
|
||||||
|
|
||||||
return formatted_response
|
return formatted_response
|
||||||
|
|
@ -134,7 +129,7 @@ async def format_with_gemini_flash(
|
||||||
response_type=f"format_{format_type}",
|
response_type=f"format_{format_type}",
|
||||||
raw_input_prompt=format_prompt,
|
raw_input_prompt=format_prompt,
|
||||||
raw_response=f"ERROR: {str(e)}",
|
raw_response=f"ERROR: {str(e)}",
|
||||||
success=f"Failure: {type(e).__name__}"
|
success=f"Failure: {type(e).__name__}",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Return empty structure based on format type
|
# Return empty structure based on format type
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@ import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
from pprint import pformat
|
|
||||||
|
|
||||||
logger = logging.getLogger("utils")
|
logger = logging.getLogger("utils")
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
@ -24,12 +23,8 @@ class Phase:
|
||||||
plans: Dict[str, str] = field(default_factory=dict)
|
plans: Dict[str, str] = field(default_factory=dict)
|
||||||
messages: List[Message] = field(default_factory=list)
|
messages: List[Message] = field(default_factory=list)
|
||||||
submitted_orders_by_power: Dict[str, List[str]] = field(default_factory=dict)
|
submitted_orders_by_power: Dict[str, List[str]] = field(default_factory=dict)
|
||||||
orders_by_power: Dict[str, List[str]] = field(
|
orders_by_power: Dict[str, List[str]] = field(default_factory=lambda: defaultdict(list))
|
||||||
default_factory=lambda: defaultdict(list)
|
results_by_power: Dict[str, List[List[str]]] = field(default_factory=lambda: defaultdict(list))
|
||||||
)
|
|
||||||
results_by_power: Dict[str, List[List[str]]] = field(
|
|
||||||
default_factory=lambda: defaultdict(list)
|
|
||||||
)
|
|
||||||
# NEW: Store phase-end summaries provided by each power
|
# NEW: Store phase-end summaries provided by each power
|
||||||
phase_summaries: Dict[str, str] = field(default_factory=dict)
|
phase_summaries: Dict[str, str] = field(default_factory=dict)
|
||||||
# NEW: Store experience/journal updates from each power for this phase
|
# NEW: Store experience/journal updates from each power for this phase
|
||||||
|
|
@ -39,9 +34,7 @@ class Phase:
|
||||||
self.plans[power_name] = plan
|
self.plans[power_name] = plan
|
||||||
|
|
||||||
def add_message(self, sender: str, recipient: str, content: str):
|
def add_message(self, sender: str, recipient: str, content: str):
|
||||||
self.messages.append(
|
self.messages.append(Message(sender=sender, recipient=recipient, content=content))
|
||||||
Message(sender=sender, recipient=recipient, content=content)
|
|
||||||
)
|
|
||||||
|
|
||||||
def add_orders(self, power: str, orders: List[str], results: List[List[str]]):
|
def add_orders(self, power: str, orders: List[str], results: List[List[str]]):
|
||||||
self.orders_by_power[power].extend(orders)
|
self.orders_by_power[power].extend(orders)
|
||||||
|
|
@ -110,14 +103,10 @@ class GameHistory:
|
||||||
phase.plans[power_name] = plan
|
phase.plans[power_name] = plan
|
||||||
logger.debug(f"Added plan for {power_name} in {phase_name}")
|
logger.debug(f"Added plan for {power_name} in {phase_name}")
|
||||||
|
|
||||||
def add_message(
|
def add_message(self, phase_name: str, sender: str, recipient: str, message_content: str):
|
||||||
self, phase_name: str, sender: str, recipient: str, message_content: str
|
|
||||||
):
|
|
||||||
phase = self._get_phase(phase_name)
|
phase = self._get_phase(phase_name)
|
||||||
if phase:
|
if phase:
|
||||||
message = Message(
|
message = Message(sender=sender, recipient=recipient, content=message_content)
|
||||||
sender=sender, recipient=recipient, content=message_content
|
|
||||||
)
|
|
||||||
phase.messages.append(message)
|
phase.messages.append(message)
|
||||||
logger.debug(f"Added message from {sender} to {recipient} in {phase_name}")
|
logger.debug(f"Added message from {sender} to {recipient} in {phase_name}")
|
||||||
|
|
||||||
|
|
@ -153,10 +142,6 @@ class GameHistory:
|
||||||
return {}
|
return {}
|
||||||
return self.phases[-1].plans
|
return self.phases[-1].plans
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_order_history_for_prompt(
|
def get_order_history_for_prompt(
|
||||||
self,
|
self,
|
||||||
game: "Game",
|
game: "Game",
|
||||||
|
|
@ -194,10 +179,7 @@ class GameHistory:
|
||||||
return str(tag).strip().lower()
|
return str(tag).strip().lower()
|
||||||
|
|
||||||
engine_phases = {ph.name: ph for ph in getattr(game, "get_phase_history", lambda: [])()}
|
engine_phases = {ph.name: ph for ph in getattr(game, "get_phase_history", lambda: [])()}
|
||||||
eng2code = {
|
eng2code = {"AUSTRIA": "AUT", "ENGLAND": "ENG", "FRANCE": "FRA", "GERMANY": "GER", "ITALY": "ITA", "RUSSIA": "RUS", "TURKEY": "TUR"}
|
||||||
"AUSTRIA": "AUT", "ENGLAND": "ENG", "FRANCE": "FRA",
|
|
||||||
"GERMANY": "GER", "ITALY": "ITA", "RUSSIA": "RUS", "TURKEY": "TUR"
|
|
||||||
}
|
|
||||||
norm = game.map.norm
|
norm = game.map.norm
|
||||||
|
|
||||||
out_lines = ["**ORDER HISTORY (Recent Rounds)**"]
|
out_lines = ["**ORDER HISTORY (Recent Rounds)**"]
|
||||||
|
|
@ -264,13 +246,6 @@ class GameHistory:
|
||||||
return "\n(No orders were issued in recent history)\n"
|
return "\n(No orders were issued in recent history)\n"
|
||||||
return "\n".join(out_lines)
|
return "\n".join(out_lines)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_messages_this_round(self, power_name: str, current_phase_name: str) -> str:
|
def get_messages_this_round(self, power_name: str, current_phase_name: str) -> str:
|
||||||
current_phase: Optional[Phase] = None
|
current_phase: Optional[Phase] = None
|
||||||
for phase_obj in self.phases:
|
for phase_obj in self.phases:
|
||||||
|
|
@ -325,11 +300,7 @@ class GameHistory:
|
||||||
if msg.recipient == power_name or (msg.recipient == "GLOBAL" and msg.sender != power_name):
|
if msg.recipient == power_name or (msg.recipient == "GLOBAL" and msg.sender != power_name):
|
||||||
# Skip if sender is this power (don't need to respond to own messages)
|
# Skip if sender is this power (don't need to respond to own messages)
|
||||||
if msg.sender != power_name:
|
if msg.sender != power_name:
|
||||||
messages_to_power.append({
|
messages_to_power.append({"sender": msg.sender, "content": msg.content, "phase": phase.name})
|
||||||
'sender': msg.sender,
|
|
||||||
'content': msg.content,
|
|
||||||
'phase': phase.name
|
|
||||||
})
|
|
||||||
|
|
||||||
# Add debug logging
|
# Add debug logging
|
||||||
logger.info(f"Found {len(messages_to_power)} messages to {power_name} across {len(recent_phases)} phases")
|
logger.info(f"Found {len(messages_to_power)} messages to {power_name} across {len(recent_phases)} phases")
|
||||||
|
|
@ -361,10 +332,10 @@ class GameHistory:
|
||||||
for msg in phase.messages:
|
for msg in phase.messages:
|
||||||
# Handle both Message objects and dict objects
|
# Handle both Message objects and dict objects
|
||||||
if isinstance(msg, Message):
|
if isinstance(msg, Message):
|
||||||
if msg.sender == sender_name and msg.recipient not in ['GLOBAL', 'ALL']:
|
if msg.sender == sender_name and msg.recipient not in ["GLOBAL", "ALL"]:
|
||||||
sender_messages.append(msg)
|
sender_messages.append(msg)
|
||||||
else: # Assume dict
|
else: # Assume dict
|
||||||
if msg['sender'] == sender_name and msg['recipient'] not in ['GLOBAL', 'ALL']:
|
if msg["sender"] == sender_name and msg["recipient"] not in ["GLOBAL", "ALL"]:
|
||||||
sender_messages.append(msg)
|
sender_messages.append(msg)
|
||||||
|
|
||||||
# Check for responses in this and next phases
|
# Check for responses in this and next phases
|
||||||
|
|
@ -374,8 +345,8 @@ class GameHistory:
|
||||||
recipient = msg.recipient
|
recipient = msg.recipient
|
||||||
msg_content = msg.content
|
msg_content = msg.content
|
||||||
else:
|
else:
|
||||||
recipient = msg['recipient']
|
recipient = msg["recipient"]
|
||||||
msg_content = msg['content']
|
msg_content = msg["content"]
|
||||||
|
|
||||||
# Look for responses in current phase and next phases
|
# Look for responses in current phase and next phases
|
||||||
found_response = False
|
found_response = False
|
||||||
|
|
@ -387,12 +358,14 @@ class GameHistory:
|
||||||
response_msgs = []
|
response_msgs = []
|
||||||
for m in check_phase.messages:
|
for m in check_phase.messages:
|
||||||
if isinstance(m, Message):
|
if isinstance(m, Message):
|
||||||
if m.sender == recipient and (m.recipient == sender_name or
|
if m.sender == recipient and (
|
||||||
(m.recipient in ['GLOBAL', 'ALL'] and sender_name in m.content)):
|
m.recipient == sender_name or (m.recipient in ["GLOBAL", "ALL"] and sender_name in m.content)
|
||||||
|
):
|
||||||
response_msgs.append(m)
|
response_msgs.append(m)
|
||||||
else: # Assume dict
|
else: # Assume dict
|
||||||
if m['sender'] == recipient and (m['recipient'] == sender_name or
|
if m["sender"] == recipient and (
|
||||||
(m['recipient'] in ['GLOBAL', 'ALL'] and sender_name in m.get('content', ''))):
|
m["recipient"] == sender_name or (m["recipient"] in ["GLOBAL", "ALL"] and sender_name in m.get("content", ""))
|
||||||
|
):
|
||||||
response_msgs.append(m)
|
response_msgs.append(m)
|
||||||
|
|
||||||
if response_msgs:
|
if response_msgs:
|
||||||
|
|
@ -402,9 +375,6 @@ class GameHistory:
|
||||||
if not found_response:
|
if not found_response:
|
||||||
if recipient not in ignored_by_power:
|
if recipient not in ignored_by_power:
|
||||||
ignored_by_power[recipient] = []
|
ignored_by_power[recipient] = []
|
||||||
ignored_by_power[recipient].append({
|
ignored_by_power[recipient].append({"phase": phase.name, "content": msg_content})
|
||||||
'phase': phase.name,
|
|
||||||
'content': msg_content
|
|
||||||
})
|
|
||||||
|
|
||||||
return ignored_by_power
|
return ignored_by_power
|
||||||
|
|
@ -3,7 +3,7 @@ import logging
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Dict, List, Tuple, Optional, Any
|
from typing import Dict, Tuple, Optional, Any
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
|
||||||
from diplomacy import Game
|
from diplomacy import Game
|
||||||
|
|
@ -19,6 +19,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# --- Serialization / Deserialization ---
|
# --- Serialization / Deserialization ---
|
||||||
|
|
||||||
|
|
||||||
def serialize_agent(agent: DiplomacyAgent) -> dict:
|
def serialize_agent(agent: DiplomacyAgent) -> dict:
|
||||||
"""Converts an agent object to a JSON-serializable dictionary."""
|
"""Converts an agent object to a JSON-serializable dictionary."""
|
||||||
return {
|
return {
|
||||||
|
|
@ -31,12 +32,8 @@ def serialize_agent(agent: DiplomacyAgent) -> dict:
|
||||||
"private_diary": agent.private_diary,
|
"private_diary": agent.private_diary,
|
||||||
}
|
}
|
||||||
|
|
||||||
def deserialize_agent(
|
|
||||||
agent_data: dict,
|
def deserialize_agent(agent_data: dict, prompts_dir: Optional[str] = None, *, override_model_id: Optional[str] = None) -> DiplomacyAgent:
|
||||||
prompts_dir: Optional[str] = None,
|
|
||||||
*,
|
|
||||||
override_model_id: Optional[str] = None
|
|
||||||
) -> DiplomacyAgent:
|
|
||||||
"""
|
"""
|
||||||
Recreates an agent object from a dictionary.
|
Recreates an agent object from a dictionary.
|
||||||
|
|
||||||
|
|
@ -70,6 +67,7 @@ def deserialize_agent(
|
||||||
# game_logic.py
|
# game_logic.py
|
||||||
_PHASE_ORDER = ["M", "R", "A"] # Movement → Retreats → Adjustments
|
_PHASE_ORDER = ["M", "R", "A"] # Movement → Retreats → Adjustments
|
||||||
|
|
||||||
|
|
||||||
def _next_phase_name(short: str) -> str:
|
def _next_phase_name(short: str) -> str:
|
||||||
"""
|
"""
|
||||||
Return the Diplomacy phase string that chronologically follows *short*.
|
Return the Diplomacy phase string that chronologically follows *short*.
|
||||||
|
|
@ -89,13 +87,9 @@ def _next_phase_name(short: str) -> str:
|
||||||
else: # winter→ spring, next year
|
else: # winter→ spring, next year
|
||||||
return f"S{year + 1}M"
|
return f"S{year + 1}M"
|
||||||
|
|
||||||
|
|
||||||
def save_game_state(
|
def save_game_state(
|
||||||
game: Game,
|
game: Game, agents: Dict[str, DiplomacyAgent], game_history: GameHistory, output_path: str, run_config: Namespace, completed_phase_name: str
|
||||||
agents: Dict[str, DiplomacyAgent],
|
|
||||||
game_history: GameHistory,
|
|
||||||
output_path: str,
|
|
||||||
run_config: Namespace,
|
|
||||||
completed_phase_name: str
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Serialise the entire game to JSON, preserving per-phase custom metadata
|
Serialise the entire game to JSON, preserving per-phase custom metadata
|
||||||
|
|
@ -130,9 +124,7 @@ def save_game_state(
|
||||||
if extras:
|
if extras:
|
||||||
previous_phase_extras[phase["name"]] = extras
|
previous_phase_extras[phase["name"]] = extras
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning(
|
logger.warning("Could not load previous save to retain metadata: %s", exc, exc_info=True)
|
||||||
"Could not load previous save to retain metadata: %s", exc, exc_info=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# -------------------------------------------------------------- #
|
# -------------------------------------------------------------- #
|
||||||
# 2. Build the fresh base structure from the diplomacy library. #
|
# 2. Build the fresh base structure from the diplomacy library. #
|
||||||
|
|
@ -143,11 +135,7 @@ def save_game_state(
|
||||||
# 3. Walk every phase and merge the metadata back in. #
|
# 3. Walk every phase and merge the metadata back in. #
|
||||||
# -------------------------------------------------------------- #
|
# -------------------------------------------------------------- #
|
||||||
# Capture the *current* snapshot of every live agent exactly once.
|
# Capture the *current* snapshot of every live agent exactly once.
|
||||||
current_state_agents = {
|
current_state_agents = {p_name: serialize_agent(p_agent) for p_name, p_agent in agents.items() if not game.powers[p_name].is_eliminated()}
|
||||||
p_name: serialize_agent(p_agent)
|
|
||||||
for p_name, p_agent in agents.items()
|
|
||||||
if not game.powers[p_name].is_eliminated()
|
|
||||||
}
|
|
||||||
|
|
||||||
for phase_block in saved_game.get("phases", []):
|
for phase_block in saved_game.get("phases", []):
|
||||||
if int(phase_block["name"][1:5]) > run_config.max_year:
|
if int(phase_block["name"][1:5]) > run_config.max_year:
|
||||||
|
|
@ -167,19 +155,13 @@ def save_game_state(
|
||||||
|
|
||||||
# Plans for this phase – may be empty in non-movement phases.
|
# Plans for this phase – may be empty in non-movement phases.
|
||||||
phase_obj = game_history._get_phase(phase_name)
|
phase_obj = game_history._get_phase(phase_name)
|
||||||
phase_block["state_history_plans"] = (
|
phase_block["state_history_plans"] = phase_obj.plans if phase_obj else {}
|
||||||
phase_obj.plans if phase_obj else {}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------------------------------- #
|
# -------------------------------------------------------------- #
|
||||||
# 4. Attach top-level metadata and write atomically. #
|
# 4. Attach top-level metadata and write atomically. #
|
||||||
# -------------------------------------------------------------- #
|
# -------------------------------------------------------------- #
|
||||||
saved_game["phase_summaries"] = getattr(game, "phase_summaries", {})
|
saved_game["phase_summaries"] = getattr(game, "phase_summaries", {})
|
||||||
saved_game["final_agent_states"] = {
|
saved_game["final_agent_states"] = {p_name: {"relationships": a.relationships, "goals": a.goals} for p_name, a in agents.items()}
|
||||||
p_name: {"relationships": a.relationships, "goals": a.goals}
|
|
||||||
for p_name, a in agents.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
# Filter out phases > max_year
|
# Filter out phases > max_year
|
||||||
# saved_game["phases"] = [
|
# saved_game["phases"] = [
|
||||||
|
|
@ -191,7 +173,6 @@ def save_game_state(
|
||||||
logger.info("Game state saved successfully.")
|
logger.info("Game state saved successfully.")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def load_game_state(
|
def load_game_state(
|
||||||
run_dir: str,
|
run_dir: str,
|
||||||
game_file_name: str,
|
game_file_name: str,
|
||||||
|
|
@ -204,7 +185,7 @@ def load_game_state(
|
||||||
raise FileNotFoundError(f"Cannot resume. Save file not found at: {game_file_path}")
|
raise FileNotFoundError(f"Cannot resume. Save file not found at: {game_file_path}")
|
||||||
|
|
||||||
logger.info(f"Loading game state from: {game_file_path}")
|
logger.info(f"Loading game state from: {game_file_path}")
|
||||||
with open(game_file_path, 'r') as f:
|
with open(game_file_path, "r") as f:
|
||||||
saved_game_data = json.load(f)
|
saved_game_data = json.load(f)
|
||||||
|
|
||||||
# If resuming, find the specified phase and truncate the data after it
|
# If resuming, find the specified phase and truncate the data after it
|
||||||
|
|
@ -213,37 +194,37 @@ def load_game_state(
|
||||||
try:
|
try:
|
||||||
# Find the index of the phase *before* the one we want to resume from.
|
# Find the index of the phase *before* the one we want to resume from.
|
||||||
# We will start the simulation *at* the resume_from_phase.
|
# We will start the simulation *at* the resume_from_phase.
|
||||||
resume_idx = next(i for i, phase in enumerate(saved_game_data['phases']) if phase['name'] == resume_from_phase)
|
resume_idx = next(i for i, phase in enumerate(saved_game_data["phases"]) if phase["name"] == resume_from_phase)
|
||||||
# Truncate the list to exclude everything after the resume phase
|
# Truncate the list to exclude everything after the resume phase
|
||||||
# Note: the state saved for a given phase represents the state at the beginning of that phase.
|
# Note: the state saved for a given phase represents the state at the beginning of that phase.
|
||||||
saved_game_data['phases'] = saved_game_data['phases'][:resume_idx+1]
|
saved_game_data["phases"] = saved_game_data["phases"][: resume_idx + 1]
|
||||||
|
|
||||||
# Wipe any data that must be regenerated.
|
# Wipe any data that must be regenerated.
|
||||||
for key in ("orders", "results", "messages"):
|
for key in ("orders", "results", "messages"):
|
||||||
saved_game_data['phases'][-1].pop(key, None)
|
saved_game_data["phases"][-1].pop(key, None)
|
||||||
logger.info(f"Game history truncated to {len(saved_game_data['phases'])} phases. The next phase to run will be {resume_from_phase}.")
|
logger.info(f"Game history truncated to {len(saved_game_data['phases'])} phases. The next phase to run will be {resume_from_phase}.")
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
# If the phase is not found, maybe it's the first phase (S1901M)
|
# If the phase is not found, maybe it's the first phase (S1901M)
|
||||||
if resume_from_phase == "S1901M":
|
if resume_from_phase == "S1901M":
|
||||||
saved_game_data['phases'] = []
|
saved_game_data["phases"] = []
|
||||||
logger.info("Resuming from S1901M. Starting with a clean history.")
|
logger.info("Resuming from S1901M. Starting with a clean history.")
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Resume phase '{resume_from_phase}' not found in the save file.")
|
raise ValueError(f"Resume phase '{resume_from_phase}' not found in the save file.")
|
||||||
|
|
||||||
# Reconstruct the Game object
|
# Reconstruct the Game object
|
||||||
last_phase = saved_game_data['phases'][-1]
|
last_phase = saved_game_data["phases"][-1]
|
||||||
|
|
||||||
# Wipe the data that must be regenerated **but preserve the keys**
|
# Wipe the data that must be regenerated **but preserve the keys**
|
||||||
last_phase['orders'] = {} # was dict
|
last_phase["orders"] = {} # was dict
|
||||||
last_phase['results'] = {} # was dict
|
last_phase["results"] = {} # was dict
|
||||||
last_phase['messages'] = []
|
last_phase["messages"] = []
|
||||||
|
|
||||||
game = from_saved_game_format(saved_game_data)
|
game = from_saved_game_format(saved_game_data)
|
||||||
|
|
||||||
game.phase_summaries = saved_game_data.get('phase_summaries', {})
|
game.phase_summaries = saved_game_data.get("phase_summaries", {})
|
||||||
|
|
||||||
# Reconstruct agents and game history from the *last* valid phase in the data
|
# Reconstruct agents and game history from the *last* valid phase in the data
|
||||||
if not saved_game_data['phases']:
|
if not saved_game_data["phases"]:
|
||||||
# This happens if we are resuming from the very beginning (S1901M)
|
# This happens if we are resuming from the very beginning (S1901M)
|
||||||
logger.info("No previous phases found. Initializing fresh agents and history.")
|
logger.info("No previous phases found. Initializing fresh agents and history.")
|
||||||
agents = {} # Will be created by the main loop
|
agents = {} # Will be created by the main loop
|
||||||
|
|
@ -251,10 +232,10 @@ def load_game_state(
|
||||||
else:
|
else:
|
||||||
# We save the game state up to & including the current (uncompleted) phase.
|
# We save the game state up to & including the current (uncompleted) phase.
|
||||||
# So we need to grab the agent state from the previous (completed) phase.
|
# So we need to grab the agent state from the previous (completed) phase.
|
||||||
if len(saved_game_data['phases']) <= 1:
|
if len(saved_game_data["phases"]) <= 1:
|
||||||
last_phase_data = {}
|
last_phase_data = {}
|
||||||
else:
|
else:
|
||||||
last_phase_data = saved_game_data['phases'][-2]
|
last_phase_data = saved_game_data["phases"][-2]
|
||||||
|
|
||||||
# -------------------- Rebuild agents -------------------- #
|
# -------------------- Rebuild agents -------------------- #
|
||||||
agents = {}
|
agents = {}
|
||||||
|
|
@ -269,18 +250,11 @@ def load_game_state(
|
||||||
elif len(provided) == 1:
|
elif len(provided) == 1:
|
||||||
power_model_map = dict(zip(powers_order, provided * len(powers_order)))
|
power_model_map = dict(zip(powers_order, provided * len(powers_order)))
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(f"Invalid --models argument: expected 1 or {len(powers_order)} items, got {len(provided)}.")
|
||||||
f"Invalid --models argument: expected 1 or {len(powers_order)} "
|
|
||||||
f"items, got {len(provided)}."
|
|
||||||
)
|
|
||||||
|
|
||||||
if "state_agents" in last_phase_data:
|
if "state_agents" in last_phase_data:
|
||||||
logger.info("Rebuilding agents from saved state...")
|
logger.info("Rebuilding agents from saved state...")
|
||||||
prompts_dir_from_config = (
|
prompts_dir_from_config = run_config.prompts_dir if run_config and hasattr(run_config, "prompts_dir") else None
|
||||||
run_config.prompts_dir
|
|
||||||
if run_config and hasattr(run_config, "prompts_dir")
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
for power_name, agent_data in last_phase_data["state_agents"].items():
|
for power_name, agent_data in last_phase_data["state_agents"].items():
|
||||||
override_id = power_model_map.get(power_name)
|
override_id = power_model_map.get(power_name)
|
||||||
agents[power_name] = deserialize_agent(
|
agents[power_name] = deserialize_agent(
|
||||||
|
|
@ -290,37 +264,27 @@ def load_game_state(
|
||||||
)
|
)
|
||||||
logger.info(f"Rebuilt {len(agents)} agents.")
|
logger.info(f"Rebuilt {len(agents)} agents.")
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError("Cannot resume: 'state_agents' key not found in the last phase of the save file.")
|
||||||
"Cannot resume: 'state_agents' key not found in the last phase "
|
|
||||||
"of the save file."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Rebuild GameHistory
|
# Rebuild GameHistory
|
||||||
game_history = GameHistory()
|
game_history = GameHistory()
|
||||||
logger.info("Rebuilding game history...")
|
logger.info("Rebuilding game history...")
|
||||||
for phase_data in saved_game_data['phases'][:-1]:
|
for phase_data in saved_game_data["phases"][:-1]:
|
||||||
phase_name = phase_data['name']
|
phase_name = phase_data["name"]
|
||||||
game_history.add_phase(phase_name)
|
game_history.add_phase(phase_name)
|
||||||
# Add messages
|
# Add messages
|
||||||
for msg in phase_data.get('messages', []):
|
for msg in phase_data.get("messages", []):
|
||||||
game_history.add_message(phase_name, msg['sender'], msg['recipient'], msg['message'])
|
game_history.add_message(phase_name, msg["sender"], msg["recipient"], msg["message"])
|
||||||
# Add plans
|
# Add plans
|
||||||
if 'state_history_plans' in phase_data:
|
if "state_history_plans" in phase_data:
|
||||||
for p_name, plan in phase_data['state_history_plans'].items():
|
for p_name, plan in phase_data["state_history_plans"].items():
|
||||||
game_history.add_plan(phase_name, p_name, plan)
|
game_history.add_plan(phase_name, p_name, plan)
|
||||||
logger.info("Game history rebuilt.")
|
logger.info("Game history rebuilt.")
|
||||||
|
|
||||||
|
|
||||||
return game, agents, game_history, run_config
|
return game, agents, game_history, run_config
|
||||||
|
|
||||||
|
|
||||||
async def initialize_new_game(
|
async def initialize_new_game(args: Namespace, game: Game, game_history: GameHistory, llm_log_file_path: str) -> Dict[str, DiplomacyAgent]:
|
||||||
args: Namespace,
|
|
||||||
game: Game,
|
|
||||||
game_history: GameHistory,
|
|
||||||
llm_log_file_path: str
|
|
||||||
) -> Dict[str, DiplomacyAgent]:
|
|
||||||
"""Initializes agents for a new game."""
|
"""Initializes agents for a new game."""
|
||||||
powers_order = sorted(list(ALL_POWERS))
|
powers_order = sorted(list(ALL_POWERS))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
# ai_diplomacy/initialization.py
|
# ai_diplomacy/initialization.py
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
import os
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from ..config import config
|
from ..config import config
|
||||||
|
|
||||||
|
|
@ -18,10 +17,11 @@ from .formatter import format_with_gemini_flash, FORMAT_INITIAL_STATE
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def initialize_agent_state_ext(
|
async def initialize_agent_state_ext(
|
||||||
agent: 'DiplomacyAgent',
|
agent: "DiplomacyAgent",
|
||||||
game: 'Game',
|
game: "Game",
|
||||||
game_history: 'GameHistory',
|
game_history: "GameHistory",
|
||||||
log_file_path: str,
|
log_file_path: str,
|
||||||
prompts_dir: Optional[str] = None,
|
prompts_dir: Optional[str] = None,
|
||||||
):
|
):
|
||||||
|
|
@ -37,21 +37,20 @@ async def initialize_agent_state_ext(
|
||||||
try:
|
try:
|
||||||
# Load the prompt template
|
# Load the prompt template
|
||||||
allowed_labels_str = ", ".join(ALLOWED_RELATIONSHIPS)
|
allowed_labels_str = ", ".join(ALLOWED_RELATIONSHIPS)
|
||||||
initial_prompt_template = _load_prompt_file(get_prompt_path('initial_state_prompt.txt'), prompts_dir=prompts_dir)
|
initial_prompt_template = _load_prompt_file(get_prompt_path("initial_state_prompt.txt"), prompts_dir=prompts_dir)
|
||||||
if not initial_prompt_template:
|
if not initial_prompt_template:
|
||||||
logger.error(f"[{power_name}] Could not load {get_prompt_path('initial_state_prompt.txt')}. Cannot initialize.")
|
logger.error(f"[{power_name}] Could not load {get_prompt_path('initial_state_prompt.txt')}. Cannot initialize.")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Format the prompt with variables
|
# Format the prompt with variables
|
||||||
initial_prompt = initial_prompt_template.format(
|
initial_prompt = initial_prompt_template.format(power_name=power_name, allowed_labels_str=allowed_labels_str)
|
||||||
power_name=power_name,
|
|
||||||
allowed_labels_str=allowed_labels_str
|
|
||||||
)
|
|
||||||
|
|
||||||
board_state = game.get_state() if game else {}
|
board_state = game.get_state() if game else {}
|
||||||
possible_orders = game.get_all_possible_orders() if game else {}
|
possible_orders = game.get_all_possible_orders() if game else {}
|
||||||
|
|
||||||
logger.debug(f"[{power_name}] Preparing context for initial state. Board state type: {type(board_state)}, possible_orders type: {type(possible_orders)}, game_history type: {type(game_history)}")
|
logger.debug(
|
||||||
|
f"[{power_name}] Preparing context for initial state. Board state type: {type(board_state)}, possible_orders type: {type(possible_orders)}, game_history type: {type(game_history)}"
|
||||||
|
)
|
||||||
# Ensure agent.client and its methods can handle None for game/board_state/etc. if that's a possibility
|
# Ensure agent.client and its methods can handle None for game/board_state/etc. if that's a possibility
|
||||||
# For initialization, game should always be present.
|
# For initialization, game should always be present.
|
||||||
|
|
||||||
|
|
@ -76,7 +75,7 @@ async def initialize_agent_state_ext(
|
||||||
log_file_path=log_file_path,
|
log_file_path=log_file_path,
|
||||||
power_name=power_name,
|
power_name=power_name,
|
||||||
phase=current_phase,
|
phase=current_phase,
|
||||||
response_type='initialization', # Context for run_llm_and_log internal error logging
|
response_type="initialization", # Context for run_llm_and_log internal error logging
|
||||||
)
|
)
|
||||||
logger.debug(f"[{power_name}] LLM response for initial state: {response[:300]}...") # Log a snippet
|
logger.debug(f"[{power_name}] LLM response for initial state: {response[:300]}...") # Log a snippet
|
||||||
|
|
||||||
|
|
@ -86,11 +85,7 @@ async def initialize_agent_state_ext(
|
||||||
if config.USE_UNFORMATTED_PROMPTS:
|
if config.USE_UNFORMATTED_PROMPTS:
|
||||||
# Format the natural language response into JSON
|
# Format the natural language response into JSON
|
||||||
formatted_response = await format_with_gemini_flash(
|
formatted_response = await format_with_gemini_flash(
|
||||||
response,
|
response, FORMAT_INITIAL_STATE, power_name=power_name, phase=current_phase, log_file_path=log_file_path
|
||||||
FORMAT_INITIAL_STATE,
|
|
||||||
power_name=power_name,
|
|
||||||
phase=current_phase,
|
|
||||||
log_file_path=log_file_path
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Use the raw response directly (already formatted)
|
# Use the raw response directly (already formatted)
|
||||||
|
|
@ -108,12 +103,16 @@ async def initialize_agent_state_ext(
|
||||||
# Defensive check for update_data type if parsing was initially considered successful
|
# Defensive check for update_data type if parsing was initially considered successful
|
||||||
if parsed_successfully:
|
if parsed_successfully:
|
||||||
if isinstance(update_data, str):
|
if isinstance(update_data, str):
|
||||||
logger.error(f"[{power_name}] _extract_json_from_text returned a string, not a dict/list, despite not raising an exception. This indicates an unexpected parsing issue. String returned: {update_data[:300]}...")
|
logger.error(
|
||||||
|
f"[{power_name}] _extract_json_from_text returned a string, not a dict/list, despite not raising an exception. This indicates an unexpected parsing issue. String returned: {update_data[:300]}..."
|
||||||
|
)
|
||||||
update_data = {} # Treat as parsing failure
|
update_data = {} # Treat as parsing failure
|
||||||
parsed_successfully = False
|
parsed_successfully = False
|
||||||
success_status = "Failure: ParsedAsStr"
|
success_status = "Failure: ParsedAsStr"
|
||||||
elif not isinstance(update_data, dict): # Expecting a dict from JSON object
|
elif not isinstance(update_data, dict): # Expecting a dict from JSON object
|
||||||
logger.error(f"[{power_name}] _extract_json_from_text returned a non-dict type ({type(update_data)}), expected dict. Data: {str(update_data)[:300]}")
|
logger.error(
|
||||||
|
f"[{power_name}] _extract_json_from_text returned a non-dict type ({type(update_data)}), expected dict. Data: {str(update_data)[:300]}"
|
||||||
|
)
|
||||||
update_data = {} # Treat as parsing failure
|
update_data = {} # Treat as parsing failure
|
||||||
parsed_successfully = False
|
parsed_successfully = False
|
||||||
success_status = "Failure: NotADict"
|
success_status = "Failure: NotADict"
|
||||||
|
|
@ -122,8 +121,8 @@ async def initialize_agent_state_ext(
|
||||||
initial_relationships_applied = False
|
initial_relationships_applied = False
|
||||||
|
|
||||||
if parsed_successfully:
|
if parsed_successfully:
|
||||||
initial_goals = update_data.get('initial_goals') or update_data.get('goals')
|
initial_goals = update_data.get("initial_goals") or update_data.get("goals")
|
||||||
initial_relationships = update_data.get('initial_relationships') or update_data.get('relationships')
|
initial_relationships = update_data.get("initial_relationships") or update_data.get("relationships")
|
||||||
|
|
||||||
if isinstance(initial_goals, list) and initial_goals:
|
if isinstance(initial_goals, list) and initial_goals:
|
||||||
agent.goals = initial_goals
|
agent.goals = initial_goals
|
||||||
|
|
@ -200,7 +199,7 @@ async def initialize_agent_state_ext(
|
||||||
response_type="initial_state_setup", # Specific type for CSV logging
|
response_type="initial_state_setup", # Specific type for CSV logging
|
||||||
raw_input_prompt=full_prompt,
|
raw_input_prompt=full_prompt,
|
||||||
raw_response=response,
|
raw_response=response,
|
||||||
success=success_status
|
success=success_status,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Final log of state after initialization attempt
|
# Final log of state after initialization attempt
|
||||||
|
|
|
||||||
|
|
@ -209,4 +209,3 @@ def _patched_generate_phase_summary(self: Game, phase_key, summary_callback=None
|
||||||
Game._generate_phase_summary = _patched_generate_phase_summary # type: ignore[assignment]
|
Game._generate_phase_summary = _patched_generate_phase_summary # type: ignore[assignment]
|
||||||
|
|
||||||
LOGGER.info("Game._generate_phase_summary patched with narrative generation.")
|
LOGGER.info("Game._generate_phase_summary patched with narrative generation.")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,7 @@ from typing import Dict, TYPE_CHECKING
|
||||||
from diplomacy.engine.message import Message, GLOBAL
|
from diplomacy.engine.message import Message, GLOBAL
|
||||||
|
|
||||||
from .agent import DiplomacyAgent
|
from .agent import DiplomacyAgent
|
||||||
from .clients import load_model_client
|
from .utils import gather_possible_orders, normalize_recipient_name
|
||||||
from .utils import gather_possible_orders, load_prompt, normalize_recipient_name
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .game_history import GameHistory
|
from .game_history import GameHistory
|
||||||
|
|
@ -21,9 +20,9 @@ load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
async def conduct_negotiations(
|
async def conduct_negotiations(
|
||||||
game: 'Game',
|
game: "Game",
|
||||||
agents: Dict[str, DiplomacyAgent],
|
agents: Dict[str, DiplomacyAgent],
|
||||||
game_history: 'GameHistory',
|
game_history: "GameHistory",
|
||||||
model_error_stats: Dict[str, Dict[str, int]],
|
model_error_stats: Dict[str, Dict[str, int]],
|
||||||
log_file_path: str,
|
log_file_path: str,
|
||||||
max_rounds: int = 3,
|
max_rounds: int = 3,
|
||||||
|
|
@ -35,12 +34,8 @@ async def conduct_negotiations(
|
||||||
"""
|
"""
|
||||||
logger.info("Starting negotiation phase.")
|
logger.info("Starting negotiation phase.")
|
||||||
|
|
||||||
active_powers = [
|
active_powers = [p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()]
|
||||||
p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()
|
eliminated_powers = [p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()]
|
||||||
]
|
|
||||||
eliminated_powers = [
|
|
||||||
p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()
|
|
||||||
]
|
|
||||||
|
|
||||||
logger.info(f"Active powers for negotiations: {active_powers}")
|
logger.info(f"Active powers for negotiations: {active_powers}")
|
||||||
if eliminated_powers:
|
if eliminated_powers:
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,13 @@
|
||||||
from dotenv import load_dotenv
|
|
||||||
import logging
|
import logging
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from .clients import load_model_client
|
|
||||||
from .game_history import GameHistory
|
from .game_history import GameHistory
|
||||||
from .agent import DiplomacyAgent
|
from .agent import DiplomacyAgent
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def planning_phase(
|
async def planning_phase(
|
||||||
game,
|
game,
|
||||||
agents: Dict[str, DiplomacyAgent],
|
agents: Dict[str, DiplomacyAgent],
|
||||||
|
|
@ -20,12 +19,8 @@ async def planning_phase(
|
||||||
Lets each power generate a strategic plan using their DiplomacyAgent.
|
Lets each power generate a strategic plan using their DiplomacyAgent.
|
||||||
"""
|
"""
|
||||||
logger.info(f"Starting planning phase for {game.current_short_phase}...")
|
logger.info(f"Starting planning phase for {game.current_short_phase}...")
|
||||||
active_powers = [
|
active_powers = [p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()]
|
||||||
p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()
|
eliminated_powers = [p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()]
|
||||||
]
|
|
||||||
eliminated_powers = [
|
|
||||||
p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()
|
|
||||||
]
|
|
||||||
|
|
||||||
logger.info(f"Active powers for planning: {active_powers}")
|
logger.info(f"Active powers for planning: {active_powers}")
|
||||||
if eliminated_powers:
|
if eliminated_powers:
|
||||||
|
|
@ -35,9 +30,7 @@ async def planning_phase(
|
||||||
|
|
||||||
board_state = game.get_state()
|
board_state = game.get_state()
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor(
|
with concurrent.futures.ThreadPoolExecutor(max_workers=len(active_powers)) as executor:
|
||||||
max_workers=len(active_powers)
|
|
||||||
) as executor:
|
|
||||||
futures = {}
|
futures = {}
|
||||||
for power_name in active_powers:
|
for power_name in active_powers:
|
||||||
if power_name not in agents:
|
if power_name not in agents:
|
||||||
|
|
@ -71,16 +64,14 @@ async def planning_phase(
|
||||||
if plan_result.startswith("Error:"):
|
if plan_result.startswith("Error:"):
|
||||||
logger.warning(f"Agent {power_name} reported an error during planning: {plan_result}")
|
logger.warning(f"Agent {power_name} reported an error during planning: {plan_result}")
|
||||||
if power_name in model_error_stats:
|
if power_name in model_error_stats:
|
||||||
model_error_stats[power_name].setdefault('planning_generation_errors', 0)
|
model_error_stats[power_name].setdefault("planning_generation_errors", 0)
|
||||||
model_error_stats[power_name]['planning_generation_errors'] += 1
|
model_error_stats[power_name]["planning_generation_errors"] += 1
|
||||||
else:
|
else:
|
||||||
model_error_stats.setdefault(f'{power_name}_planning_generation_errors', 0)
|
model_error_stats.setdefault(f"{power_name}_planning_generation_errors", 0)
|
||||||
model_error_stats[f'{power_name}_planning_generation_errors'] += 1
|
model_error_stats[f"{power_name}_planning_generation_errors"] += 1
|
||||||
elif plan_result:
|
elif plan_result:
|
||||||
agent.add_journal_entry(f"Generated plan for {game.current_short_phase}: {plan_result[:100]}...")
|
agent.add_journal_entry(f"Generated plan for {game.current_short_phase}: {plan_result[:100]}...")
|
||||||
game_history.add_plan(
|
game_history.add_plan(game.current_short_phase, power_name, plan_result)
|
||||||
game.current_short_phase, power_name, plan_result
|
|
||||||
)
|
|
||||||
logger.debug(f"Added plan for {power_name} to history.")
|
logger.debug(f"Added plan for {power_name} to history.")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Agent {power_name} returned an empty plan.")
|
logger.warning(f"Agent {power_name} returned an empty plan.")
|
||||||
|
|
@ -88,11 +79,11 @@ async def planning_phase(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Exception during planning result processing for {power_name}: {e}")
|
logger.error(f"Exception during planning result processing for {power_name}: {e}")
|
||||||
if power_name in model_error_stats:
|
if power_name in model_error_stats:
|
||||||
model_error_stats[power_name].setdefault('planning_execution_errors', 0)
|
model_error_stats[power_name].setdefault("planning_execution_errors", 0)
|
||||||
model_error_stats[power_name]['planning_execution_errors'] += 1
|
model_error_stats[power_name]["planning_execution_errors"] += 1
|
||||||
else:
|
else:
|
||||||
model_error_stats.setdefault(f'{power_name}_planning_execution_errors', 0)
|
model_error_stats.setdefault(f"{power_name}_planning_execution_errors", 0)
|
||||||
model_error_stats[f'{power_name}_planning_execution_errors'] += 1
|
model_error_stats[f"{power_name}_planning_execution_errors"] += 1
|
||||||
|
|
||||||
logger.info("Planning phase processing complete.")
|
logger.info("Planning phase processing complete.")
|
||||||
return game_history
|
return game_history
|
||||||
|
|
@ -14,6 +14,7 @@ import re
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]:
|
def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]:
|
||||||
"""
|
"""
|
||||||
Return graph[PROV]['ARMY'|'FLEET'] = list of 3-letter neighbour provinces.
|
Return graph[PROV]['ARMY'|'FLEET'] = list of 3-letter neighbour provinces.
|
||||||
|
|
@ -27,9 +28,7 @@ def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]:
|
||||||
if len(loc.split("/")[0]) == 3
|
if len(loc.split("/")[0]) == 3
|
||||||
}
|
}
|
||||||
|
|
||||||
graph: Dict[str, Dict[str, List[str]]] = {
|
graph: Dict[str, Dict[str, List[str]]] = {p: {"ARMY": [], "FLEET": []} for p in provs}
|
||||||
p: {"ARMY": [], "FLEET": []} for p in provs
|
|
||||||
}
|
|
||||||
|
|
||||||
# ── helper: list every concrete variant of a province ─────────────────
|
# ── helper: list every concrete variant of a province ─────────────────
|
||||||
def variants(code: str) -> List[str]:
|
def variants(code: str) -> List[str]:
|
||||||
|
|
@ -55,11 +54,7 @@ def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]:
|
||||||
graph[src]["ARMY"].append(dest)
|
graph[src]["ARMY"].append(dest)
|
||||||
|
|
||||||
# FLEET — any src variant that can host a fleet is valid
|
# FLEET — any src variant that can host a fleet is valid
|
||||||
if any(
|
if any(game_map.abuts("F", sv, "-", dv) for sv in src_vers for dv in dest_vers):
|
||||||
game_map.abuts("F", sv, "-", dv)
|
|
||||||
for sv in src_vers
|
|
||||||
for dv in dest_vers
|
|
||||||
):
|
|
||||||
graph[src]["FLEET"].append(dest)
|
graph[src]["FLEET"].append(dest)
|
||||||
|
|
||||||
# ── tidy up duplicates / order ---------------------------------------
|
# ── tidy up duplicates / order ---------------------------------------
|
||||||
|
|
@ -70,24 +65,23 @@ def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]:
|
||||||
return graph
|
return graph
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def bfs_shortest_path(
|
def bfs_shortest_path(
|
||||||
graph: Dict[str, Dict[str, List[str]]],
|
graph: Dict[str, Dict[str, List[str]]],
|
||||||
board_state: BoardState,
|
board_state: BoardState,
|
||||||
game_map: GameMap, # Added game_map
|
game_map: GameMap, # Added game_map
|
||||||
start_loc_full: str, # This is a FULL location name like 'VIE' or 'STP/SC'
|
start_loc_full: str, # This is a FULL location name like 'VIE' or 'STP/SC'
|
||||||
unit_type: str,
|
unit_type: str,
|
||||||
is_target_func: Callable[[str, BoardState], bool] # Expects SHORT name for loc
|
is_target_func: Callable[[str, BoardState], bool], # Expects SHORT name for loc
|
||||||
) -> Optional[List[str]]: # Returns path of SHORT names
|
) -> Optional[List[str]]: # Returns path of SHORT names
|
||||||
"""Performs BFS to find the shortest path from start_loc to a target satisfying is_target_func."""
|
"""Performs BFS to find the shortest path from start_loc to a target satisfying is_target_func."""
|
||||||
|
|
||||||
# Convert full start location to short province name
|
# Convert full start location to short province name
|
||||||
start_loc_short = game_map.loc_name.get(start_loc_full, start_loc_full)
|
start_loc_short = game_map.loc_name.get(start_loc_full, start_loc_full)
|
||||||
if '/' in start_loc_short: # If it was STP/SC, loc_name gives STP. If it was VIE, loc_name gives VIE.
|
if "/" in start_loc_short: # If it was STP/SC, loc_name gives STP. If it was VIE, loc_name gives VIE.
|
||||||
start_loc_short = start_loc_short[:3]
|
start_loc_short = start_loc_short[:3]
|
||||||
# If start_loc_full was already short (e.g. 'VIE'), get might return it as is, or its value if it was a key.
|
# If start_loc_full was already short (e.g. 'VIE'), get might return it as is, or its value if it was a key.
|
||||||
# A simpler way for non-coastal full (like 'VIE') or already short:
|
# A simpler way for non-coastal full (like 'VIE') or already short:
|
||||||
if '/' not in start_loc_full:
|
if "/" not in start_loc_full:
|
||||||
start_loc_short = start_loc_full[:3] # Ensures 'VIE' -> 'VIE', 'PAR' -> 'PAR'
|
start_loc_short = start_loc_full[:3] # Ensures 'VIE' -> 'VIE', 'PAR' -> 'PAR'
|
||||||
else: # Has '/', e.g. 'STP/SC'
|
else: # Has '/', e.g. 'STP/SC'
|
||||||
start_loc_short = start_loc_full[:3] # 'STP/SC' -> 'STP'
|
start_loc_short = start_loc_full[:3] # 'STP/SC' -> 'STP'
|
||||||
|
|
@ -119,10 +113,11 @@ def bfs_shortest_path(
|
||||||
queue.append((next_loc_short, new_path))
|
queue.append((next_loc_short, new_path))
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
# --- Helper functions for context generation ---
|
# --- Helper functions for context generation ---
|
||||||
def get_unit_at_location(board_state: BoardState, location: str) -> Optional[str]:
|
def get_unit_at_location(board_state: BoardState, location: str) -> Optional[str]:
|
||||||
"""Returns the full unit string (e.g., 'A PAR (FRA)') if a unit is at the location, else None."""
|
"""Returns the full unit string (e.g., 'A PAR (FRA)') if a unit is at the location, else None."""
|
||||||
for power, unit_list in board_state.get('units', {}).items():
|
for power, unit_list in board_state.get("units", {}).items():
|
||||||
for unit_str in unit_list: # e.g., "A PAR", "F STP/SC"
|
for unit_str in unit_list: # e.g., "A PAR", "F STP/SC"
|
||||||
parts = unit_str.split(" ")
|
parts = unit_str.split(" ")
|
||||||
if len(parts) == 2:
|
if len(parts) == 2:
|
||||||
|
|
@ -131,24 +126,26 @@ def get_unit_at_location(board_state: BoardState, location: str) -> Optional[str
|
||||||
return f"{parts[0]} {location} ({power})"
|
return f"{parts[0]} {location} ({power})"
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_sc_controller(game_map: GameMap, board_state: BoardState, location: str) -> Optional[str]:
|
def get_sc_controller(game_map: GameMap, board_state: BoardState, location: str) -> Optional[str]:
|
||||||
"""Returns the controlling power's name if the location is an SC, else None."""
|
"""Returns the controlling power's name if the location is an SC, else None."""
|
||||||
# Normalize location to base province name, as SCs are tied to provinces, not specific coasts
|
# Normalize location to base province name, as SCs are tied to provinces, not specific coasts
|
||||||
loc_province_name = game_map.loc_name.get(location, location).upper()[:3]
|
loc_province_name = game_map.loc_name.get(location, location).upper()[:3]
|
||||||
if loc_province_name not in game_map.scs:
|
if loc_province_name not in game_map.scs:
|
||||||
return None
|
return None
|
||||||
for power, sc_list in board_state.get('centers', {}).items():
|
for power, sc_list in board_state.get("centers", {}).items():
|
||||||
if loc_province_name in sc_list:
|
if loc_province_name in sc_list:
|
||||||
return power
|
return power
|
||||||
return None # Unowned SC
|
return None # Unowned SC
|
||||||
|
|
||||||
|
|
||||||
def get_shortest_path_to_friendly_unit(
|
def get_shortest_path_to_friendly_unit(
|
||||||
board_state: BoardState,
|
board_state: BoardState,
|
||||||
graph: Dict[str, Dict[str, List[str]]],
|
graph: Dict[str, Dict[str, List[str]]],
|
||||||
game_map: GameMap, # Added game_map
|
game_map: GameMap, # Added game_map
|
||||||
power_name: str,
|
power_name: str,
|
||||||
start_unit_loc_full: str,
|
start_unit_loc_full: str,
|
||||||
start_unit_type: str
|
start_unit_type: str,
|
||||||
) -> Optional[Tuple[str, List[str]]]:
|
) -> Optional[Tuple[str, List[str]]]:
|
||||||
"""Finds the shortest path to any friendly unit of the same power."""
|
"""Finds the shortest path to any friendly unit of the same power."""
|
||||||
|
|
||||||
|
|
@ -183,7 +180,7 @@ def get_nearest_enemy_units(
|
||||||
power_name: str,
|
power_name: str,
|
||||||
start_unit_loc_full: str,
|
start_unit_loc_full: str,
|
||||||
start_unit_type: str,
|
start_unit_type: str,
|
||||||
n: int = 3
|
n: int = 3,
|
||||||
) -> List[Tuple[str, List[str]]]:
|
) -> List[Tuple[str, List[str]]]:
|
||||||
"""Finds up to N nearest enemy units, sorted by path length."""
|
"""Finds up to N nearest enemy units, sorted by path length."""
|
||||||
enemy_paths: List[Tuple[str, List[str]]] = [] # (enemy_unit_str, path_short_names)
|
enemy_paths: List[Tuple[str, List[str]]] = [] # (enemy_unit_str, path_short_names)
|
||||||
|
|
@ -205,9 +202,9 @@ def get_nearest_enemy_units(
|
||||||
|
|
||||||
for target_enemy_loc_full, enemy_unit_str in all_enemy_unit_locations_full:
|
for target_enemy_loc_full, enemy_unit_str in all_enemy_unit_locations_full:
|
||||||
target_enemy_loc_short = game_map.loc_name.get(target_enemy_loc_full, target_enemy_loc_full)
|
target_enemy_loc_short = game_map.loc_name.get(target_enemy_loc_full, target_enemy_loc_full)
|
||||||
if '/' in target_enemy_loc_short:
|
if "/" in target_enemy_loc_short:
|
||||||
target_enemy_loc_short = target_enemy_loc_short[:3]
|
target_enemy_loc_short = target_enemy_loc_short[:3]
|
||||||
if '/' not in target_enemy_loc_full:
|
if "/" not in target_enemy_loc_full:
|
||||||
target_enemy_loc_short = target_enemy_loc_full[:3]
|
target_enemy_loc_short = target_enemy_loc_full[:3]
|
||||||
else:
|
else:
|
||||||
target_enemy_loc_short = target_enemy_loc_full[:3]
|
target_enemy_loc_short = target_enemy_loc_full[:3]
|
||||||
|
|
@ -281,7 +278,7 @@ def get_adjacent_territory_details(
|
||||||
board_state: BoardState,
|
board_state: BoardState,
|
||||||
unit_loc_full: str, # The location of the unit whose adjacencies we're checking
|
unit_loc_full: str, # The location of the unit whose adjacencies we're checking
|
||||||
unit_type: str, # ARMY or FLEET of the unit at unit_loc_full
|
unit_type: str, # ARMY or FLEET of the unit at unit_loc_full
|
||||||
graph: Dict[str, Dict[str, List[str]]]
|
graph: Dict[str, Dict[str, List[str]]],
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generates a string describing adjacent territories and units that can interact with them."""
|
"""Generates a string describing adjacent territories and units that can interact with them."""
|
||||||
output_lines: List[str] = []
|
output_lines: List[str] = []
|
||||||
|
|
@ -291,9 +288,9 @@ def get_adjacent_territory_details(
|
||||||
# For fleets, graph[unit_loc_full]['FLEET'] gives full loc names (incl coasts)
|
# For fleets, graph[unit_loc_full]['FLEET'] gives full loc names (incl coasts)
|
||||||
# THIS COMMENT IS NOW OUTDATED. Graph uses short names for keys and values.
|
# THIS COMMENT IS NOW OUTDATED. Graph uses short names for keys and values.
|
||||||
unit_loc_short = game_map.loc_name.get(unit_loc_full, unit_loc_full)
|
unit_loc_short = game_map.loc_name.get(unit_loc_full, unit_loc_full)
|
||||||
if '/' in unit_loc_short:
|
if "/" in unit_loc_short:
|
||||||
unit_loc_short = unit_loc_short[:3]
|
unit_loc_short = unit_loc_short[:3]
|
||||||
if '/' not in unit_loc_full:
|
if "/" not in unit_loc_full:
|
||||||
unit_loc_short = unit_loc_full[:3]
|
unit_loc_short = unit_loc_full[:3]
|
||||||
else:
|
else:
|
||||||
unit_loc_short = unit_loc_full[:3]
|
unit_loc_short = unit_loc_full[:3]
|
||||||
|
|
@ -308,11 +305,11 @@ def get_adjacent_territory_details(
|
||||||
continue
|
continue
|
||||||
processed_adj_provinces.add(adj_loc_short)
|
processed_adj_provinces.add(adj_loc_short)
|
||||||
|
|
||||||
adj_loc_type = game_map.loc_type.get(adj_loc_short, 'UNKNOWN').upper()
|
adj_loc_type = game_map.loc_type.get(adj_loc_short, "UNKNOWN").upper()
|
||||||
if adj_loc_type == 'COAST' or adj_loc_type == 'LAND':
|
if adj_loc_type == "COAST" or adj_loc_type == "LAND":
|
||||||
adj_loc_type_display = 'LAND' if adj_loc_type == 'LAND' else 'COAST'
|
adj_loc_type_display = "LAND" if adj_loc_type == "LAND" else "COAST"
|
||||||
elif adj_loc_type == 'WATER':
|
elif adj_loc_type == "WATER":
|
||||||
adj_loc_type_display = 'WATER'
|
adj_loc_type_display = "WATER"
|
||||||
else: # SHUT etc.
|
else: # SHUT etc.
|
||||||
adj_loc_type_display = adj_loc_type
|
adj_loc_type_display = adj_loc_type
|
||||||
|
|
||||||
|
|
@ -330,8 +327,7 @@ def get_adjacent_territory_details(
|
||||||
# "Can support/move to" - Simplified: list units in *further* adjacent provinces
|
# "Can support/move to" - Simplified: list units in *further* adjacent provinces
|
||||||
# A true "can support/move to" would require checking possible orders of those further units.
|
# A true "can support/move to" would require checking possible orders of those further units.
|
||||||
# further_adj_provinces are short names from the graph
|
# further_adj_provinces are short names from the graph
|
||||||
further_adj_provinces_short = graph.get(adj_loc_short, {}).get('ARMY', []) + \
|
further_adj_provinces_short = graph.get(adj_loc_short, {}).get("ARMY", []) + graph.get(adj_loc_short, {}).get("FLEET", [])
|
||||||
graph.get(adj_loc_short, {}).get('FLEET', [])
|
|
||||||
|
|
||||||
supporting_units_info = []
|
supporting_units_info = []
|
||||||
processed_further_provinces = set()
|
processed_further_provinces = set()
|
||||||
|
|
@ -388,12 +384,12 @@ def generate_rich_order_context_xml(game: Any, power_name: str, possible_orders_
|
||||||
continue
|
continue
|
||||||
|
|
||||||
unit_type_char = unit_str_full.split(" ")[0] # 'A' or 'F'
|
unit_type_char = unit_str_full.split(" ")[0] # 'A' or 'F'
|
||||||
unit_type_long = "ARMY" if unit_type_char == 'A' else "FLEET"
|
unit_type_long = "ARMY" if unit_type_char == "A" else "FLEET"
|
||||||
|
|
||||||
loc_province_short = game_map.loc_name.get(unit_loc_full, unit_loc_full).upper()[:3]
|
loc_province_short = game_map.loc_name.get(unit_loc_full, unit_loc_full).upper()[:3]
|
||||||
loc_type_short = game_map.loc_type.get(loc_province_short, "UNKNOWN").upper()
|
loc_type_short = game_map.loc_type.get(loc_province_short, "UNKNOWN").upper()
|
||||||
if loc_type_short == 'COAST' or loc_type_short == 'LAND':
|
if loc_type_short == "COAST" or loc_type_short == "LAND":
|
||||||
loc_type_display = 'LAND' if loc_type_short == 'LAND' else 'COAST'
|
loc_type_display = "LAND" if loc_type_short == "LAND" else "COAST"
|
||||||
else:
|
else:
|
||||||
loc_type_display = loc_type_short
|
loc_type_display = loc_type_short
|
||||||
|
|
||||||
|
|
@ -401,7 +397,7 @@ def generate_rich_order_context_xml(game: Any, power_name: str, possible_orders_
|
||||||
current_unit_lines.append(f' <UnitContext loc="{unit_loc_full}">')
|
current_unit_lines.append(f' <UnitContext loc="{unit_loc_full}">')
|
||||||
|
|
||||||
# Unit Information section
|
# Unit Information section
|
||||||
current_unit_lines.append(' <UnitInformation>')
|
current_unit_lines.append(" <UnitInformation>")
|
||||||
sc_owner_at_loc = get_sc_controller(game_map, board_state, unit_loc_full)
|
sc_owner_at_loc = get_sc_controller(game_map, board_state, unit_loc_full)
|
||||||
header_content = f"Strategic territory held by {power_name}: {unit_loc_full} ({loc_type_display})"
|
header_content = f"Strategic territory held by {power_name}: {unit_loc_full} ({loc_type_display})"
|
||||||
if sc_owner_at_loc == power_name:
|
if sc_owner_at_loc == power_name:
|
||||||
|
|
@ -410,55 +406,61 @@ def generate_rich_order_context_xml(game: Any, power_name: str, possible_orders_
|
||||||
header_content += f" (SC controlled by {sc_owner_at_loc})"
|
header_content += f" (SC controlled by {sc_owner_at_loc})"
|
||||||
current_unit_lines.append(f" {header_content}")
|
current_unit_lines.append(f" {header_content}")
|
||||||
current_unit_lines.append(f" Units present: {unit_str_full}")
|
current_unit_lines.append(f" Units present: {unit_str_full}")
|
||||||
current_unit_lines.append(' </UnitInformation>')
|
current_unit_lines.append(" </UnitInformation>")
|
||||||
|
|
||||||
# Possible moves section
|
# Possible moves section
|
||||||
current_unit_lines.append(' <PossibleMoves>')
|
current_unit_lines.append(" <PossibleMoves>")
|
||||||
current_unit_lines.append(" Possible moves:")
|
current_unit_lines.append(" Possible moves:")
|
||||||
for order_str in unit_specific_possible_orders:
|
for order_str in unit_specific_possible_orders:
|
||||||
current_unit_lines.append(f" {order_str}")
|
current_unit_lines.append(f" {order_str}")
|
||||||
current_unit_lines.append(' </PossibleMoves>')
|
current_unit_lines.append(" </PossibleMoves>")
|
||||||
|
|
||||||
# Nearest enemy units section
|
# Nearest enemy units section
|
||||||
enemy_units_info = get_nearest_enemy_units(board_state, graph, game_map, power_name, unit_loc_full, unit_type_long, n=3)
|
enemy_units_info = get_nearest_enemy_units(board_state, graph, game_map, power_name, unit_loc_full, unit_type_long, n=3)
|
||||||
current_unit_lines.append(' <NearestEnemyUnits>')
|
current_unit_lines.append(" <NearestEnemyUnits>")
|
||||||
if enemy_units_info:
|
if enemy_units_info:
|
||||||
current_unit_lines.append(" Nearest units (not ours):")
|
current_unit_lines.append(" Nearest units (not ours):")
|
||||||
for enemy_unit_str, enemy_path_short in enemy_units_info:
|
for enemy_unit_str, enemy_path_short in enemy_units_info:
|
||||||
current_unit_lines.append(f" {enemy_unit_str}, path=[{unit_loc_full}→{('→'.join(enemy_path_short[1:])) if len(enemy_path_short) > 1 else enemy_path_short[0]}]")
|
current_unit_lines.append(
|
||||||
|
f" {enemy_unit_str}, path=[{unit_loc_full}→{('→'.join(enemy_path_short[1:])) if len(enemy_path_short) > 1 else enemy_path_short[0]}]"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
current_unit_lines.append(" Nearest units (not ours): None found")
|
current_unit_lines.append(" Nearest units (not ours): None found")
|
||||||
current_unit_lines.append(' </NearestEnemyUnits>')
|
current_unit_lines.append(" </NearestEnemyUnits>")
|
||||||
|
|
||||||
# Nearest supply centers (not controlled by us) section
|
# Nearest supply centers (not controlled by us) section
|
||||||
uncontrolled_scs_info = get_nearest_uncontrolled_scs(game_map, board_state, graph, power_name, unit_loc_full, unit_type_long, n=3)
|
uncontrolled_scs_info = get_nearest_uncontrolled_scs(game_map, board_state, graph, power_name, unit_loc_full, unit_type_long, n=3)
|
||||||
current_unit_lines.append(' <NearestUncontrolledSupplyCenters>')
|
current_unit_lines.append(" <NearestUncontrolledSupplyCenters>")
|
||||||
if uncontrolled_scs_info:
|
if uncontrolled_scs_info:
|
||||||
current_unit_lines.append(" Nearest supply centers (not controlled by us):")
|
current_unit_lines.append(" Nearest supply centers (not controlled by us):")
|
||||||
for sc_str, dist, sc_path_short in uncontrolled_scs_info:
|
for sc_str, dist, sc_path_short in uncontrolled_scs_info:
|
||||||
current_unit_lines.append(f" {sc_str}, dist={dist}, path=[{unit_loc_full}→{('→'.join(sc_path_short[1:])) if len(sc_path_short) > 1 else sc_path_short[0]}]")
|
current_unit_lines.append(
|
||||||
|
f" {sc_str}, dist={dist}, path=[{unit_loc_full}→{('→'.join(sc_path_short[1:])) if len(sc_path_short) > 1 else sc_path_short[0]}]"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
current_unit_lines.append(" Nearest supply centers (not controlled by us): None found")
|
current_unit_lines.append(" Nearest supply centers (not controlled by us): None found")
|
||||||
current_unit_lines.append(' </NearestUncontrolledSupplyCenters>')
|
current_unit_lines.append(" </NearestUncontrolledSupplyCenters>")
|
||||||
|
|
||||||
# Adjacent territories details section
|
# Adjacent territories details section
|
||||||
adj_details_str = get_adjacent_territory_details(game_map, board_state, unit_loc_full, unit_type_long, graph)
|
adj_details_str = get_adjacent_territory_details(game_map, board_state, unit_loc_full, unit_type_long, graph)
|
||||||
current_unit_lines.append(' <AdjacentTerritories>')
|
current_unit_lines.append(" <AdjacentTerritories>")
|
||||||
if adj_details_str:
|
if adj_details_str:
|
||||||
current_unit_lines.append(" Adjacent territories (including units that can support/move to the adjacent territory):")
|
current_unit_lines.append(" Adjacent territories (including units that can support/move to the adjacent territory):")
|
||||||
# Assuming adj_details_str is already formatted with newlines and indentation for its content
|
# Assuming adj_details_str is already formatted with newlines and indentation for its content
|
||||||
# We might need to indent adj_details_str if it's a single block of text
|
# We might need to indent adj_details_str if it's a single block of text
|
||||||
# For now, let's add a standard indent to each line of adj_details_str if it contains newlines
|
# For now, let's add a standard indent to each line of adj_details_str if it contains newlines
|
||||||
if '\n' in adj_details_str:
|
if "\n" in adj_details_str:
|
||||||
indented_adj_details = "\n".join([f" {line}" for line in adj_details_str.split('\n')])
|
indented_adj_details = "\n".join([f" {line}" for line in adj_details_str.split("\n")])
|
||||||
current_unit_lines.append(indented_adj_details)
|
current_unit_lines.append(indented_adj_details)
|
||||||
else:
|
else:
|
||||||
current_unit_lines.append(f" {adj_details_str}")
|
current_unit_lines.append(f" {adj_details_str}")
|
||||||
else:
|
else:
|
||||||
current_unit_lines.append(" Adjacent territories: None relevant or all are empty/uncontested by direct threats.") # Added more descriptive else
|
current_unit_lines.append(
|
||||||
current_unit_lines.append(' </AdjacentTerritories>')
|
" Adjacent territories: None relevant or all are empty/uncontested by direct threats."
|
||||||
|
) # Added more descriptive else
|
||||||
|
current_unit_lines.append(" </AdjacentTerritories>")
|
||||||
|
|
||||||
current_unit_lines.append(' </UnitContext>')
|
current_unit_lines.append(" </UnitContext>")
|
||||||
final_context_lines.extend(current_unit_lines)
|
final_context_lines.extend(current_unit_lines)
|
||||||
|
|
||||||
final_context_lines.append("</PossibleOrdersContext>")
|
final_context_lines.append("</PossibleOrdersContext>")
|
||||||
|
|
@ -469,7 +471,6 @@ def generate_rich_order_context_xml(game: Any, power_name: str, possible_orders_
|
||||||
# Regex and tiny helpers
|
# Regex and tiny helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
import re
|
|
||||||
from typing import Tuple, List, Dict, Optional, Any
|
from typing import Tuple, List, Dict, Optional, Any
|
||||||
|
|
||||||
# ── order-syntax matchers ─────────────────────────────────────────────────
|
# ── order-syntax matchers ─────────────────────────────────────────────────
|
||||||
|
|
@ -478,6 +479,7 @@ _HOLD_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? H$") # NEW
|
||||||
_RETREAT_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? R [A-Z]{3}(?:/[A-Z]{2})?$")
|
_RETREAT_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? R [A-Z]{3}(?:/[A-Z]{2})?$")
|
||||||
_ADJUST_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? [BD]$") # build / disband
|
_ADJUST_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? [BD]$") # build / disband
|
||||||
|
|
||||||
|
|
||||||
def _is_hold_order(order: str) -> bool: # NEW
|
def _is_hold_order(order: str) -> bool: # NEW
|
||||||
return bool(_HOLD_RE.match(order.strip()))
|
return bool(_HOLD_RE.match(order.strip()))
|
||||||
|
|
||||||
|
|
@ -486,15 +488,19 @@ def _norm_power(name: str) -> str:
|
||||||
"""Trim & uppercase for reliable comparisons."""
|
"""Trim & uppercase for reliable comparisons."""
|
||||||
return name.strip().upper()
|
return name.strip().upper()
|
||||||
|
|
||||||
|
|
||||||
def _is_simple_move(order: str) -> bool:
|
def _is_simple_move(order: str) -> bool:
|
||||||
return bool(_SIMPLE_MOVE_RE.match(order.strip()))
|
return bool(_SIMPLE_MOVE_RE.match(order.strip()))
|
||||||
|
|
||||||
|
|
||||||
def _is_retreat_order(order: str) -> bool:
|
def _is_retreat_order(order: str) -> bool:
|
||||||
return bool(_RETREAT_RE.match(order.strip()))
|
return bool(_RETREAT_RE.match(order.strip()))
|
||||||
|
|
||||||
|
|
||||||
def _is_adjust_order(order: str) -> bool:
|
def _is_adjust_order(order: str) -> bool:
|
||||||
return bool(_ADJUST_RE.match(order.strip()))
|
return bool(_ADJUST_RE.match(order.strip()))
|
||||||
|
|
||||||
|
|
||||||
def _split_move(order: str) -> Tuple[str, str]:
|
def _split_move(order: str) -> Tuple[str, str]:
|
||||||
"""Return ('A BUD', 'TRI') from 'A BUD - TRI' (validated move only)."""
|
"""Return ('A BUD', 'TRI') from 'A BUD - TRI' (validated move only)."""
|
||||||
unit_part, dest = order.split(" - ")
|
unit_part, dest = order.split(" - ")
|
||||||
|
|
@ -505,6 +511,7 @@ def _split_move(order: str) -> Tuple[str, str]:
|
||||||
# Gather *all* friendly support orders for a given move
|
# Gather *all* friendly support orders for a given move
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _all_support_examples(
|
def _all_support_examples(
|
||||||
mover: str,
|
mover: str,
|
||||||
dest: str,
|
dest: str,
|
||||||
|
|
@ -526,6 +533,7 @@ def _all_support_examples(
|
||||||
|
|
||||||
return supports
|
return supports
|
||||||
|
|
||||||
|
|
||||||
def _all_support_hold_examples(
|
def _all_support_hold_examples(
|
||||||
holder: str,
|
holder: str,
|
||||||
all_orders: Dict[str, List[str]],
|
all_orders: Dict[str, List[str]],
|
||||||
|
|
@ -550,6 +558,7 @@ def _all_support_hold_examples(
|
||||||
# Province-type resolver (handles short codes, coasts, seas)
|
# Province-type resolver (handles short codes, coasts, seas)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _province_type_display(game_map, prov_short: str) -> str:
|
def _province_type_display(game_map, prov_short: str) -> str:
|
||||||
"""
|
"""
|
||||||
Return 'LAND', 'COAST', or 'WATER' for the 3-letter province code.
|
Return 'LAND', 'COAST', or 'WATER' for the 3-letter province code.
|
||||||
|
|
@ -593,6 +602,7 @@ def _dest_occupancy_desc(
|
||||||
# Adjacent-territory lines (used by movement-phase builder)
|
# Adjacent-territory lines (used by movement-phase builder)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _adjacent_territory_lines(
|
def _adjacent_territory_lines(
|
||||||
graph,
|
graph,
|
||||||
game_map,
|
game_map,
|
||||||
|
|
@ -633,13 +643,9 @@ def _adjacent_territory_lines(
|
||||||
pwr = unit_here.split(" ")[-1].strip("()")
|
pwr = unit_here.split(" ")[-1].strip("()")
|
||||||
if pwr == our_power:
|
if pwr == our_power:
|
||||||
friend_descr = unit_here.split(" (")[0]
|
friend_descr = unit_here.split(" (")[0]
|
||||||
lines.append(
|
lines.append(f"{indent2}Support hold: {mover_descr} S {friend_descr}")
|
||||||
f"{indent2}Support hold: {mover_descr} S {friend_descr}"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
lines.append(
|
lines.append(f"{indent2}-> {unit_here} can support or contest {mover_descr}’s moves and vice-versa")
|
||||||
f"{indent2}-> {unit_here} can support or contest {mover_descr}’s moves and vice-versa"
|
|
||||||
)
|
|
||||||
|
|
||||||
return lines
|
return lines
|
||||||
|
|
||||||
|
|
@ -648,6 +654,7 @@ def _adjacent_territory_lines(
|
||||||
# Movement-phase generator (UNCHANGED LOGIC)
|
# Movement-phase generator (UNCHANGED LOGIC)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _generate_rich_order_context_movement(
|
def _generate_rich_order_context_movement(
|
||||||
game: Any,
|
game: Any,
|
||||||
power_name: str,
|
power_name: str,
|
||||||
|
|
@ -672,18 +679,14 @@ def _generate_rich_order_context_movement(
|
||||||
if _norm_power(unit_power) != me:
|
if _norm_power(unit_power) != me:
|
||||||
continue # Skip units that aren’t ours
|
continue # Skip units that aren’t ours
|
||||||
|
|
||||||
mover_descr, _ = _split_move(
|
mover_descr, _ = _split_move(f"{unit_full_str.split(' ')[0]} {unit_loc_full} - {unit_loc_full}")
|
||||||
f"{unit_full_str.split(' ')[0]} {unit_loc_full} - {unit_loc_full}"
|
|
||||||
)
|
|
||||||
|
|
||||||
prov_short = game_map.loc_name.get(unit_loc_full, unit_loc_full)[:3]
|
prov_short = game_map.loc_name.get(unit_loc_full, unit_loc_full)[:3]
|
||||||
prov_type_disp = _province_type_display(game_map, prov_short)
|
prov_type_disp = _province_type_display(game_map, prov_short)
|
||||||
sc_tag = " (SC)" if prov_short in game_map.scs else ""
|
sc_tag = " (SC)" if prov_short in game_map.scs else ""
|
||||||
|
|
||||||
owner = get_sc_controller(game_map, board_state, unit_loc_full) or "None"
|
owner = get_sc_controller(game_map, board_state, unit_loc_full) or "None"
|
||||||
owner_line = (
|
owner_line = f"Held by {owner} (You)" if owner == power_name else f"Held by {owner}"
|
||||||
f"Held by {owner} (You)" if owner == power_name else f"Held by {owner}"
|
|
||||||
)
|
|
||||||
|
|
||||||
ind = " "
|
ind = " "
|
||||||
block: List[str] = [f"<Territory {prov_short}>"]
|
block: List[str] = [f"<Territory {prov_short}>"]
|
||||||
|
|
@ -693,18 +696,16 @@ def _generate_rich_order_context_movement(
|
||||||
|
|
||||||
# ----- Adjacent territories -----
|
# ----- Adjacent territories -----
|
||||||
block.append("# Adjacent territories:")
|
block.append("# Adjacent territories:")
|
||||||
block.extend(
|
block.extend(_adjacent_territory_lines(graph, game_map, board_state, unit_loc_full, mover_descr, power_name))
|
||||||
_adjacent_territory_lines(
|
|
||||||
graph, game_map, board_state,
|
|
||||||
unit_loc_full, mover_descr, power_name
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# ----- Nearest enemy units -----
|
# ----- Nearest enemy units -----
|
||||||
block.append("# Nearest units (not ours):")
|
block.append("# Nearest units (not ours):")
|
||||||
enemies = get_nearest_enemy_units(
|
enemies = get_nearest_enemy_units(
|
||||||
board_state, graph, game_map,
|
board_state,
|
||||||
power_name, unit_loc_full,
|
graph,
|
||||||
|
game_map,
|
||||||
|
power_name,
|
||||||
|
unit_loc_full,
|
||||||
"ARMY" if mover_descr.startswith("A") else "FLEET",
|
"ARMY" if mover_descr.startswith("A") else "FLEET",
|
||||||
n=3,
|
n=3,
|
||||||
)
|
)
|
||||||
|
|
@ -715,8 +716,11 @@ def _generate_rich_order_context_movement(
|
||||||
# ----- Nearest uncontrolled SCs -----
|
# ----- Nearest uncontrolled SCs -----
|
||||||
block.append("# Nearest supply centers (not controlled by us):")
|
block.append("# Nearest supply centers (not controlled by us):")
|
||||||
scs = get_nearest_uncontrolled_scs(
|
scs = get_nearest_uncontrolled_scs(
|
||||||
game_map, board_state, graph,
|
game_map,
|
||||||
power_name, unit_loc_full,
|
board_state,
|
||||||
|
graph,
|
||||||
|
power_name,
|
||||||
|
unit_loc_full,
|
||||||
"ARMY" if mover_descr.startswith("A") else "FLEET",
|
"ARMY" if mover_descr.startswith("A") else "FLEET",
|
||||||
n=3,
|
n=3,
|
||||||
)
|
)
|
||||||
|
|
@ -737,9 +741,7 @@ def _generate_rich_order_context_movement(
|
||||||
# ---- Moves (same behaviour as before) ----
|
# ---- Moves (same behaviour as before) ----
|
||||||
for mv in simple_moves:
|
for mv in simple_moves:
|
||||||
mover, dest = _split_move(mv)
|
mover, dest = _split_move(mv)
|
||||||
occ = _dest_occupancy_desc(
|
occ = _dest_occupancy_desc(dest.split("/")[0][:3], game_map, board_state, power_name)
|
||||||
dest.split("/")[0][:3], game_map, board_state, power_name
|
|
||||||
)
|
|
||||||
block.append(f"{ind}{mv} {occ}")
|
block.append(f"{ind}{mv} {occ}")
|
||||||
|
|
||||||
for s in _all_support_examples(mover, dest, possible_orders_for_power):
|
for s in _all_support_examples(mover, dest, possible_orders_for_power):
|
||||||
|
|
@ -753,7 +755,6 @@ def _generate_rich_order_context_movement(
|
||||||
for s in _all_support_hold_examples(holder, possible_orders_for_power):
|
for s in _all_support_hold_examples(holder, possible_orders_for_power):
|
||||||
block.append(f"{ind * 2}Available Support: {s}")
|
block.append(f"{ind * 2}Available Support: {s}")
|
||||||
|
|
||||||
|
|
||||||
block.append(f"</Territory {prov_short}>")
|
block.append(f"</Territory {prov_short}>")
|
||||||
blocks.append("\n".join(block))
|
blocks.append("\n".join(block))
|
||||||
|
|
||||||
|
|
@ -764,6 +765,7 @@ def _generate_rich_order_context_movement(
|
||||||
# Retreat-phase builder – echo orders verbatim, no tags
|
# Retreat-phase builder – echo orders verbatim, no tags
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _generate_rich_order_context_retreat(
|
def _generate_rich_order_context_retreat(
|
||||||
game: Any,
|
game: Any,
|
||||||
power_name: str,
|
power_name: str,
|
||||||
|
|
@ -784,11 +786,11 @@ def _generate_rich_order_context_retreat(
|
||||||
return "\n".join(lines) if lines else "(No dislodged units)"
|
return "\n".join(lines) if lines else "(No dislodged units)"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Adjustment-phase builder – summary line + orders, no WAIVEs, no tags
|
# Adjustment-phase builder – summary line + orders, no WAIVEs, no tags
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _generate_rich_order_context_adjustment(
|
def _generate_rich_order_context_adjustment(
|
||||||
game: Any,
|
game: Any,
|
||||||
power_name: str,
|
power_name: str,
|
||||||
|
|
@ -824,13 +826,11 @@ def _generate_rich_order_context_adjustment(
|
||||||
return "\n".join(lines) if len(lines) > 1 else summary
|
return "\n".join(lines) if len(lines) > 1 else summary
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Phase-dispatch wrapper (public entry point)
|
# Phase-dispatch wrapper (public entry point)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def generate_rich_order_context(
|
def generate_rich_order_context(
|
||||||
game: Any,
|
game: Any,
|
||||||
power_name: str,
|
power_name: str,
|
||||||
|
|
@ -847,21 +847,13 @@ def generate_rich_order_context(
|
||||||
phase_type = game.current_short_phase[-1]
|
phase_type = game.current_short_phase[-1]
|
||||||
|
|
||||||
if phase_type == "M": # Movement
|
if phase_type == "M": # Movement
|
||||||
return _generate_rich_order_context_movement(
|
return _generate_rich_order_context_movement(game, power_name, possible_orders_for_power)
|
||||||
game, power_name, possible_orders_for_power
|
|
||||||
)
|
|
||||||
|
|
||||||
if phase_type == "R": # Retreat
|
if phase_type == "R": # Retreat
|
||||||
return _generate_rich_order_context_retreat(
|
return _generate_rich_order_context_retreat(game, power_name, possible_orders_for_power)
|
||||||
game, power_name, possible_orders_for_power
|
|
||||||
)
|
|
||||||
|
|
||||||
if phase_type == "A": # Adjustment (build / disband)
|
if phase_type == "A": # Adjustment (build / disband)
|
||||||
return _generate_rich_order_context_adjustment(
|
return _generate_rich_order_context_adjustment(game, power_name, possible_orders_for_power)
|
||||||
game, power_name, possible_orders_for_power
|
|
||||||
)
|
|
||||||
|
|
||||||
# Fallback – treat unknown formats as movement
|
# Fallback – treat unknown formats as movement
|
||||||
return _generate_rich_order_context_movement(
|
return _generate_rich_order_context_movement(game, power_name, possible_orders_for_power)
|
||||||
game, power_name, possible_orders_for_power
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,6 @@ from .possible_order_context import (
|
||||||
generate_rich_order_context,
|
generate_rich_order_context,
|
||||||
generate_rich_order_context_xml,
|
generate_rich_order_context_xml,
|
||||||
)
|
)
|
||||||
import os
|
|
||||||
from .game_history import GameHistory # Assuming GameHistory is correctly importable
|
from .game_history import GameHistory # Assuming GameHistory is correctly importable
|
||||||
|
|
||||||
# placeholder for diplomacy.Game to avoid circular or direct dependency if not needed for typehinting only
|
# placeholder for diplomacy.Game to avoid circular or direct dependency if not needed for typehinting only
|
||||||
|
|
@ -230,4 +229,3 @@ def construct_order_generation_prompt(
|
||||||
logger.debug(f"Final order generation prompt preview for {power_name}: {final_prompt[:500]}...")
|
logger.debug(f"Final order generation prompt preview for {power_name}: {final_prompt[:500]}...")
|
||||||
|
|
||||||
return final_prompt
|
return final_prompt
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,7 @@ def atomic_write_json(data: dict, filepath: str):
|
||||||
|
|
||||||
# Write to a temporary file in the same directory
|
# Write to a temporary file in the same directory
|
||||||
temp_filepath = f"{filepath}.tmp.{os.getpid()}"
|
temp_filepath = f"{filepath}.tmp.{os.getpid()}"
|
||||||
with open(temp_filepath, 'w', encoding='utf-8') as f:
|
with open(temp_filepath, "w", encoding="utf-8") as f:
|
||||||
json.dump(data, f, indent=4)
|
json.dump(data, f, indent=4)
|
||||||
|
|
||||||
# Atomically rename the temporary file to the final destination
|
# Atomically rename the temporary file to the final destination
|
||||||
|
|
@ -110,11 +110,7 @@ def get_special_models() -> Dict[str, str]:
|
||||||
# Use a different Gemini model for formatting
|
# Use a different Gemini model for formatting
|
||||||
export AI_DIPLOMACY_FORMATTER_MODEL="gemini-2.0-flash"
|
export AI_DIPLOMACY_FORMATTER_MODEL="gemini-2.0-flash"
|
||||||
"""
|
"""
|
||||||
return {
|
return {"phase_summary": config.AI_DIPLOMACY_NARRATIVE_MODEL, "formatter": config.AI_DIPLOMACY_FORMATTER_MODEL}
|
||||||
"phase_summary": config.AI_DIPLOMACY_NARRATIVE_MODEL,
|
|
||||||
"formatter": config.AI_DIPLOMACY_FORMATTER_MODEL
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def gather_possible_orders(game: Game, power_name: str) -> Dict[str, List[str]]:
|
def gather_possible_orders(game: Game, power_name: str) -> Dict[str, List[str]]:
|
||||||
|
|
@ -175,15 +171,19 @@ async def get_valid_orders(
|
||||||
|
|
||||||
# ── 2. Type check ──────────────────────────────────────────
|
# ── 2. Type check ──────────────────────────────────────────
|
||||||
if not isinstance(raw_orders, list):
|
if not isinstance(raw_orders, list):
|
||||||
logger.warning("[%s] Orders received from LLM are not a list: %s. Using fallback.",
|
logger.warning("[%s] Orders received from LLM are not a list: %s. Using fallback.", power_name, raw_orders)
|
||||||
power_name, raw_orders)
|
|
||||||
model_error_stats[client.model_name]["order_decoding_errors"] += 1
|
model_error_stats[client.model_name]["order_decoding_errors"] += 1
|
||||||
return {"valid": client.fallback_orders(possible_orders), "invalid": []}
|
return {"valid": client.fallback_orders(possible_orders), "invalid": []}
|
||||||
|
|
||||||
# ── 3. Round-trip validation with engine ───────────────────
|
# ── 3. Round-trip validation with engine ───────────────────
|
||||||
CODE_TO_ENGINE = {
|
CODE_TO_ENGINE = {
|
||||||
"AUT": "AUSTRIA", "ENG": "ENGLAND", "FRA": "FRANCE",
|
"AUT": "AUSTRIA",
|
||||||
"GER": "GERMANY", "ITA": "ITALY", "RUS": "RUSSIA", "TUR": "TURKEY",
|
"ENG": "ENGLAND",
|
||||||
|
"FRA": "FRANCE",
|
||||||
|
"GER": "GERMANY",
|
||||||
|
"ITA": "ITALY",
|
||||||
|
"RUS": "RUSSIA",
|
||||||
|
"TUR": "TURKEY",
|
||||||
}
|
}
|
||||||
engine_power = power_name if power_name in game.powers else CODE_TO_ENGINE[power_name]
|
engine_power = power_name if power_name in game.powers else CODE_TO_ENGINE[power_name]
|
||||||
|
|
||||||
|
|
@ -226,8 +226,6 @@ async def get_valid_orders(
|
||||||
return {"valid": valid, "invalid": invalid}
|
return {"valid": valid, "invalid": invalid}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_and_compare_orders(
|
def normalize_and_compare_orders(
|
||||||
issued_orders: Dict[str, List[str]],
|
issued_orders: Dict[str, List[str]],
|
||||||
accepted_orders_dict: Dict[str, List[str]],
|
accepted_orders_dict: Dict[str, List[str]],
|
||||||
|
|
@ -348,8 +346,6 @@ def load_prompt(filename: str, prompts_dir: Optional[str] = None) -> str:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# == New LLM Response Logging Function ==
|
# == New LLM Response Logging Function ==
|
||||||
def log_llm_response(
|
def log_llm_response(
|
||||||
log_file_path: str,
|
log_file_path: str,
|
||||||
|
|
@ -374,14 +370,18 @@ def log_llm_response(
|
||||||
with open(log_file_path, "a", newline="", encoding="utf-8") as csvfile:
|
with open(log_file_path, "a", newline="", encoding="utf-8") as csvfile:
|
||||||
# Added "raw_input" to fieldnames
|
# Added "raw_input" to fieldnames
|
||||||
fieldnames = ["model", "power", "phase", "response_type", "raw_input", "raw_response", "success"]
|
fieldnames = ["model", "power", "phase", "response_type", "raw_input", "raw_response", "success"]
|
||||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames,
|
writer = csv.DictWriter(
|
||||||
|
csvfile,
|
||||||
|
fieldnames=fieldnames,
|
||||||
quoting=csv.QUOTE_ALL, # Quote all fields to handle commas and newlines
|
quoting=csv.QUOTE_ALL, # Quote all fields to handle commas and newlines
|
||||||
escapechar='\\') # Use backslash for escaping
|
escapechar="\\",
|
||||||
|
) # Use backslash for escaping
|
||||||
|
|
||||||
if not file_exists:
|
if not file_exists:
|
||||||
writer.writeheader() # Write header only if file is new
|
writer.writeheader() # Write header only if file is new
|
||||||
|
|
||||||
writer.writerow({
|
writer.writerow(
|
||||||
|
{
|
||||||
"model": model_name,
|
"model": model_name,
|
||||||
"power": power_name if power_name else "game", # Use 'game' if no specific power
|
"power": power_name if power_name else "game", # Use 'game' if no specific power
|
||||||
"phase": phase,
|
"phase": phase,
|
||||||
|
|
@ -389,12 +389,12 @@ def log_llm_response(
|
||||||
"raw_input": raw_input_prompt, # Added raw_input to the row
|
"raw_input": raw_input_prompt, # Added raw_input to the row
|
||||||
"raw_response": raw_response,
|
"raw_response": raw_response,
|
||||||
"success": success,
|
"success": success,
|
||||||
})
|
}
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to log LLM response to {log_file_path}: {e}", exc_info=True)
|
logger.error(f"Failed to log LLM response to {log_file_path}: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def run_llm_and_log(
|
async def run_llm_and_log(
|
||||||
client: "BaseModelClient",
|
client: "BaseModelClient",
|
||||||
prompt: str,
|
prompt: str,
|
||||||
|
|
@ -428,8 +428,7 @@ async def run_llm_and_log(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if attempt == attempts - 1:
|
if attempt == attempts - 1:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"API Error during LLM call for {client.model_name}/{power_name}/{response_type} "
|
f"API Error during LLM call for {client.model_name}/{power_name}/{response_type} in phase {phase}: {e}",
|
||||||
f"in phase {phase}: {e}",
|
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
# Back-off before the next attempt (unless this was the last)
|
# Back-off before the next attempt (unless this was the last)
|
||||||
|
|
@ -439,7 +438,6 @@ async def run_llm_and_log(
|
||||||
return raw_response
|
return raw_response
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# This generates a few lines of random alphanum chars to inject into the
|
# This generates a few lines of random alphanum chars to inject into the
|
||||||
# system prompt. This lets us use temp=0 while still getting variation
|
# system prompt. This lets us use temp=0 while still getting variation
|
||||||
# between trials.
|
# between trials.
|
||||||
|
|
@ -447,15 +445,8 @@ async def run_llm_and_log(
|
||||||
# ensure valid json outputs.
|
# ensure valid json outputs.
|
||||||
def generate_random_seed(n_lines: int = 5, n_chars_per_line: int = 80):
|
def generate_random_seed(n_lines: int = 5, n_chars_per_line: int = 80):
|
||||||
# Generate x lines of y random alphanumeric characters
|
# Generate x lines of y random alphanumeric characters
|
||||||
seed_lines = [
|
seed_lines = ["".join(random.choices(string.ascii_letters + string.digits, k=n_chars_per_line)) for _ in range(n_lines)]
|
||||||
''.join(random.choices(string.ascii_letters + string.digits, k=n_chars_per_line))
|
random_seed_block = "<RANDOM SEED PLEASE IGNORE>\n" + "\n".join(seed_lines) + "\n</RANDOM SEED>"
|
||||||
for _ in range(n_lines)
|
|
||||||
]
|
|
||||||
random_seed_block = (
|
|
||||||
"<RANDOM SEED PLEASE IGNORE>\n" +
|
|
||||||
"\n".join(seed_lines) +
|
|
||||||
"\n</RANDOM SEED>"
|
|
||||||
)
|
|
||||||
return random_seed_block
|
return random_seed_block
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -473,6 +464,7 @@ def get_prompt_path(prompt_name: str) -> str:
|
||||||
else:
|
else:
|
||||||
return prompt_name
|
return prompt_name
|
||||||
|
|
||||||
|
|
||||||
def normalize_recipient_name(recipient: str) -> str:
|
def normalize_recipient_name(recipient: str) -> str:
|
||||||
"""Normalize recipient names to handle LLM typos and abbreviations."""
|
"""Normalize recipient names to handle LLM typos and abbreviations."""
|
||||||
if not recipient:
|
if not recipient:
|
||||||
|
|
@ -482,24 +474,24 @@ def normalize_recipient_name(recipient: str) -> str:
|
||||||
|
|
||||||
# Handle common LLM typos and abbreviations found in data
|
# Handle common LLM typos and abbreviations found in data
|
||||||
name_mapping = {
|
name_mapping = {
|
||||||
'EGMANY': 'GERMANY',
|
"EGMANY": "GERMANY",
|
||||||
'GERMAN': 'GERMANY',
|
"GERMAN": "GERMANY",
|
||||||
'UK': 'ENGLAND',
|
"UK": "ENGLAND",
|
||||||
'BRIT': 'ENGLAND',
|
"BRIT": "ENGLAND",
|
||||||
'ENGLAND': 'ENGLAND', # Keep as-is
|
"ENGLAND": "ENGLAND", # Keep as-is
|
||||||
'FRANCE': 'FRANCE', # Keep as-is
|
"FRANCE": "FRANCE", # Keep as-is
|
||||||
'GERMANY': 'GERMANY', # Keep as-is
|
"GERMANY": "GERMANY", # Keep as-is
|
||||||
'ITALY': 'ITALY', # Keep as-is
|
"ITALY": "ITALY", # Keep as-is
|
||||||
'AUSTRIA': 'AUSTRIA', # Keep as-is
|
"AUSTRIA": "AUSTRIA", # Keep as-is
|
||||||
'RUSSIA': 'RUSSIA', # Keep as-is
|
"RUSSIA": "RUSSIA", # Keep as-is
|
||||||
'TURKEY': 'TURKEY', # Keep as-is
|
"TURKEY": "TURKEY", # Keep as-is
|
||||||
'Germany': 'GERMANY',
|
"Germany": "GERMANY",
|
||||||
'England': 'ENGLAND',
|
"England": "ENGLAND",
|
||||||
'France': 'FRANCE',
|
"France": "FRANCE",
|
||||||
'Italy': 'ITALY',
|
"Italy": "ITALY",
|
||||||
'Russia': 'RUSSIA',
|
"Russia": "RUSSIA",
|
||||||
'Austria': 'AUSTRIA',
|
"Austria": "AUSTRIA",
|
||||||
'Turkey': 'TURKEY',
|
"Turkey": "TURKEY",
|
||||||
}
|
}
|
||||||
|
|
||||||
normalized = name_mapping.get(recipient, recipient)
|
normalized = name_mapping.get(recipient, recipient)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue