mirror of
https://github.com/GoodStartLabs/AI_Diplomacy.git
synced 2026-05-02 17:46:00 +00:00
Ruff format and check
This commit is contained in:
parent
aadedd6512
commit
ad817695ef
13 changed files with 738 additions and 973 deletions
|
|
@ -10,7 +10,7 @@ import ast
|
||||||
from ..config import config
|
from ..config import config
|
||||||
|
|
||||||
# Assuming BaseModelClient is importable from clients.py in the same directory
|
# Assuming BaseModelClient is importable from clients.py in the same directory
|
||||||
from .clients import BaseModelClient, load_model_client
|
from .clients import BaseModelClient
|
||||||
|
|
||||||
# Import load_prompt and the new logging wrapper from utils
|
# Import load_prompt and the new logging wrapper from utils
|
||||||
from .utils import load_prompt, run_llm_and_log, log_llm_response, get_prompt_path
|
from .utils import load_prompt, run_llm_and_log, log_llm_response, get_prompt_path
|
||||||
|
|
@ -336,7 +336,7 @@ class DiplomacyAgent:
|
||||||
)
|
)
|
||||||
# If even the last resort doesn't give a dict, return empty dict
|
# If even the last resort doesn't give a dict, return empty dict
|
||||||
return {}
|
return {}
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error(f"[{self.power_name}] All JSON extraction attempts failed. Original text: {original_text[:500]}...")
|
logger.error(f"[{self.power_name}] All JSON extraction attempts failed. Original text: {original_text[:500]}...")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
@ -965,7 +965,7 @@ class DiplomacyAgent:
|
||||||
other_powers = [p for p in game.powers if p != power_name]
|
other_powers = [p for p in game.powers if p != power_name]
|
||||||
|
|
||||||
# Create a readable board state string from the board_state dict
|
# Create a readable board state string from the board_state dict
|
||||||
board_state_str = f"Board State:\n"
|
board_state_str = "Board State:\n"
|
||||||
for p_name, power_data in board_state.get("powers", {}).items():
|
for p_name, power_data in board_state.get("powers", {}).items():
|
||||||
# Get units and centers from the board state
|
# Get units and centers from the board state
|
||||||
units = power_data.get("units", [])
|
units = power_data.get("units", [])
|
||||||
|
|
|
||||||
|
|
@ -1,41 +1,41 @@
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
from json import JSONDecodeError
|
|
||||||
import re
|
import re
|
||||||
import logging
|
import logging
|
||||||
import ast # For literal_eval in JSON fallback parsing
|
import ast # For literal_eval in JSON fallback parsing
|
||||||
import aiohttp # For direct HTTP requests to Responses API
|
import aiohttp # For direct HTTP requests to Responses API
|
||||||
|
|
||||||
from typing import List, Dict, Optional, Any, Tuple, NamedTuple
|
from typing import List, Dict, Optional, Tuple, NamedTuple
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
# Use Async versions of clients
|
# Use Async versions of clients
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
from openai import AsyncOpenAI as AsyncDeepSeekOpenAI # Alias for clarity
|
from openai import AsyncOpenAI as AsyncDeepSeekOpenAI # Alias for clarity
|
||||||
from anthropic import AsyncAnthropic
|
from anthropic import AsyncAnthropic
|
||||||
import asyncio
|
import asyncio
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
import google.generativeai as genai
|
import google.generativeai as genai
|
||||||
from together import AsyncTogether
|
from together import AsyncTogether
|
||||||
from together.error import APIError as TogetherAPIError # For specific error handling
|
from together.error import APIError as TogetherAPIError # For specific error handling
|
||||||
|
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from diplomacy.engine.message import GLOBAL
|
|
||||||
from .game_history import GameHistory
|
from .game_history import GameHistory
|
||||||
from .utils import load_prompt, run_llm_and_log, log_llm_response, generate_random_seed, get_prompt_path
|
from .utils import load_prompt, run_llm_and_log, log_llm_response, generate_random_seed, get_prompt_path
|
||||||
|
|
||||||
# Import DiplomacyAgent for type hinting if needed, but avoid circular import if possible
|
# Import DiplomacyAgent for type hinting if needed, but avoid circular import if possible
|
||||||
from .prompt_constructor import construct_order_generation_prompt, build_context_prompt
|
from .prompt_constructor import construct_order_generation_prompt, build_context_prompt
|
||||||
# Moved formatter imports to avoid circular import - imported locally where needed
|
# Moved formatter imports to avoid circular import - imported locally where needed
|
||||||
|
|
||||||
# set logger back to just info
|
# set logger back to just info
|
||||||
logger = logging.getLogger("client")
|
logger = logging.getLogger("client")
|
||||||
logger.setLevel(logging.DEBUG) # Keep debug for now during async changes
|
logger.setLevel(logging.DEBUG) # Keep debug for now during async changes
|
||||||
# Note: BasicConfig might conflict if already configured in lm_game. Keep client-specific for now.
|
# Note: BasicConfig might conflict if already configured in lm_game. Keep client-specific for now.
|
||||||
# logging.basicConfig(level=logging.DEBUG) # Might be redundant if lm_game configures root
|
# logging.basicConfig(level=logging.DEBUG) # Might be redundant if lm_game configures root
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
##############################################################################
|
##############################################################################
|
||||||
# 1) Base Interface
|
# 1) Base Interface
|
||||||
##############################################################################
|
##############################################################################
|
||||||
|
|
@ -76,13 +76,13 @@ class BaseModelClient:
|
||||||
board_state,
|
board_state,
|
||||||
power_name: str,
|
power_name: str,
|
||||||
possible_orders: Dict[str, List[str]],
|
possible_orders: Dict[str, List[str]],
|
||||||
conversation_text: str, # This is GameHistory
|
conversation_text: str, # This is GameHistory
|
||||||
model_error_stats: dict,
|
model_error_stats: dict,
|
||||||
log_file_path: str,
|
log_file_path: str,
|
||||||
phase: str,
|
phase: str,
|
||||||
agent_goals: Optional[List[str]] = None,
|
agent_goals: Optional[List[str]] = None,
|
||||||
agent_relationships: Optional[Dict[str, str]] = None,
|
agent_relationships: Optional[Dict[str, str]] = None,
|
||||||
agent_private_diary_str: Optional[str] = None, # Added
|
agent_private_diary_str: Optional[str] = None, # Added
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""
|
"""
|
||||||
1) Builds the prompt with conversation context if available
|
1) Builds the prompt with conversation context if available
|
||||||
|
|
@ -98,7 +98,7 @@ class BaseModelClient:
|
||||||
board_state=board_state,
|
board_state=board_state,
|
||||||
power_name=power_name,
|
power_name=power_name,
|
||||||
possible_orders=possible_orders,
|
possible_orders=possible_orders,
|
||||||
game_history=game_history_obj, # Pass GameHistory object
|
game_history=game_history_obj, # Pass GameHistory object
|
||||||
agent_goals=agent_goals,
|
agent_goals=agent_goals,
|
||||||
agent_relationships=agent_relationships,
|
agent_relationships=agent_relationships,
|
||||||
agent_private_diary_str=agent_private_diary_str,
|
agent_private_diary_str=agent_private_diary_str,
|
||||||
|
|
@ -108,7 +108,7 @@ class BaseModelClient:
|
||||||
raw_response = ""
|
raw_response = ""
|
||||||
# Initialize success status. Will be updated based on outcome.
|
# Initialize success status. Will be updated based on outcome.
|
||||||
success_status = "Failure: Initialized"
|
success_status = "Failure: Initialized"
|
||||||
parsed_orders_for_return = self.fallback_orders(possible_orders) # Default to fallback
|
parsed_orders_for_return = self.fallback_orders(possible_orders) # Default to fallback
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Call LLM using the logging wrapper
|
# Call LLM using the logging wrapper
|
||||||
|
|
@ -118,24 +118,19 @@ class BaseModelClient:
|
||||||
log_file_path=log_file_path,
|
log_file_path=log_file_path,
|
||||||
power_name=power_name,
|
power_name=power_name,
|
||||||
phase=phase,
|
phase=phase,
|
||||||
response_type='order', # Context for run_llm_and_log's own error logging
|
response_type="order", # Context for run_llm_and_log's own error logging
|
||||||
temperature=0
|
temperature=0,
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f"[{self.model_name}] Raw LLM response for {power_name} orders:\n{raw_response}"
|
|
||||||
)
|
)
|
||||||
|
logger.debug(f"[{self.model_name}] Raw LLM response for {power_name} orders:\n{raw_response}")
|
||||||
|
|
||||||
# Conditionally format the response based on USE_UNFORMATTED_PROMPTS
|
# Conditionally format the response based on USE_UNFORMATTED_PROMPTS
|
||||||
if config.USE_UNFORMATTED_PROMPTS:
|
if config.USE_UNFORMATTED_PROMPTS:
|
||||||
# Local import to avoid circular dependency
|
# Local import to avoid circular dependency
|
||||||
from .formatter import format_with_gemini_flash, FORMAT_ORDERS
|
from .formatter import format_with_gemini_flash, FORMAT_ORDERS
|
||||||
|
|
||||||
# Format the natural language response into structured format
|
# Format the natural language response into structured format
|
||||||
formatted_response = await format_with_gemini_flash(
|
formatted_response = await format_with_gemini_flash(
|
||||||
raw_response,
|
raw_response, FORMAT_ORDERS, power_name=power_name, phase=phase, log_file_path=log_file_path
|
||||||
FORMAT_ORDERS,
|
|
||||||
power_name=power_name,
|
|
||||||
phase=phase,
|
|
||||||
log_file_path=log_file_path
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Use the raw response directly (already formatted)
|
# Use the raw response directly (already formatted)
|
||||||
|
|
@ -145,9 +140,7 @@ class BaseModelClient:
|
||||||
move_list = self._extract_moves(formatted_response, power_name)
|
move_list = self._extract_moves(formatted_response, power_name)
|
||||||
|
|
||||||
if not move_list:
|
if not move_list:
|
||||||
logger.warning(
|
logger.warning(f"[{self.model_name}] Could not extract moves for {power_name}. Using fallback.")
|
||||||
f"[{self.model_name}] Could not extract moves for {power_name}. Using fallback."
|
|
||||||
)
|
|
||||||
if model_error_stats is not None and self.model_name in model_error_stats:
|
if model_error_stats is not None and self.model_name in model_error_stats:
|
||||||
model_error_stats[self.model_name].setdefault("order_decoding_errors", 0)
|
model_error_stats[self.model_name].setdefault("order_decoding_errors", 0)
|
||||||
model_error_stats[self.model_name]["order_decoding_errors"] += 1
|
model_error_stats[self.model_name]["order_decoding_errors"] += 1
|
||||||
|
|
@ -173,8 +166,8 @@ class BaseModelClient:
|
||||||
# because the LLM didn't provide a fully usable set of orders without intervention/fallbacks.
|
# because the LLM didn't provide a fully usable set of orders without intervention/fallbacks.
|
||||||
# The fallback_orders logic within _validate_orders might fill in missing pieces,
|
# The fallback_orders logic within _validate_orders might fill in missing pieces,
|
||||||
# but the key is that the LLM *proposed* invalid moves.
|
# but the key is that the LLM *proposed* invalid moves.
|
||||||
if not validated_moves: # All LLM moves were invalid
|
if not validated_moves: # All LLM moves were invalid
|
||||||
logger.warning(f"[{power_name}] All LLM-proposed moves were invalid. Using fallbacks. Invalid: {invalid_moves_list}")
|
logger.warning(f"[{power_name}] All LLM-proposed moves were invalid. Using fallbacks. Invalid: {invalid_moves_list}")
|
||||||
else:
|
else:
|
||||||
logger.info(f"[{power_name}] Some LLM-proposed moves were invalid. Using fallbacks/validated. Invalid: {invalid_moves_list}")
|
logger.info(f"[{power_name}] Some LLM-proposed moves were invalid. Using fallbacks/validated. Invalid: {invalid_moves_list}")
|
||||||
else:
|
else:
|
||||||
|
|
@ -186,16 +179,16 @@ class BaseModelClient:
|
||||||
# Fallback is already set to parsed_orders_for_return
|
# Fallback is already set to parsed_orders_for_return
|
||||||
finally:
|
finally:
|
||||||
# Log the attempt regardless of outcome
|
# Log the attempt regardless of outcome
|
||||||
if log_file_path: # Only log if a path is provided
|
if log_file_path: # Only log if a path is provided
|
||||||
log_llm_response(
|
log_llm_response(
|
||||||
log_file_path=log_file_path,
|
log_file_path=log_file_path,
|
||||||
model_name=self.model_name,
|
model_name=self.model_name,
|
||||||
power_name=power_name,
|
power_name=power_name,
|
||||||
phase=phase,
|
phase=phase,
|
||||||
response_type="order_generation", # Specific type for CSV logging
|
response_type="order_generation", # Specific type for CSV logging
|
||||||
raw_input_prompt=prompt, # Renamed from 'prompt' to match log_llm_response arg
|
raw_input_prompt=prompt, # Renamed from 'prompt' to match log_llm_response arg
|
||||||
raw_response=raw_response,
|
raw_response=raw_response,
|
||||||
success=success_status
|
success=success_status,
|
||||||
# token_usage and cost can be added later if available and if log_llm_response supports them
|
# token_usage and cost can be added later if available and if log_llm_response supports them
|
||||||
)
|
)
|
||||||
return parsed_orders_for_return
|
return parsed_orders_for_return
|
||||||
|
|
@ -216,9 +209,7 @@ class BaseModelClient:
|
||||||
|
|
||||||
if not matches:
|
if not matches:
|
||||||
# Some LLMs might not put the colon or might have triple backtick fences.
|
# Some LLMs might not put the colon or might have triple backtick fences.
|
||||||
logger.debug(
|
logger.debug(f"[{self.model_name}] Regex parse #1 failed for {power_name}. Trying alternative patterns.")
|
||||||
f"[{self.model_name}] Regex parse #1 failed for {power_name}. Trying alternative patterns."
|
|
||||||
)
|
|
||||||
|
|
||||||
# 1b) Check for inline JSON after "PARSABLE OUTPUT"
|
# 1b) Check for inline JSON after "PARSABLE OUTPUT"
|
||||||
pattern_alt = r"PARSABLE OUTPUT\s*\{(.*?)\}\s*$"
|
pattern_alt = r"PARSABLE OUTPUT\s*\{(.*?)\}\s*$"
|
||||||
|
|
@ -226,47 +217,35 @@ class BaseModelClient:
|
||||||
|
|
||||||
if not matches:
|
if not matches:
|
||||||
# 1c) Check for **PARSABLE OUTPUT:** pattern (with asterisks)
|
# 1c) Check for **PARSABLE OUTPUT:** pattern (with asterisks)
|
||||||
logger.debug(
|
logger.debug(f"[{self.model_name}] Regex parse #2 failed for {power_name}. Trying asterisk-wrapped pattern.")
|
||||||
f"[{self.model_name}] Regex parse #2 failed for {power_name}. Trying asterisk-wrapped pattern."
|
|
||||||
)
|
|
||||||
pattern_asterisk = r"\*\*PARSABLE OUTPUT:\*\*\s*(\{[\s\S]*?\})"
|
pattern_asterisk = r"\*\*PARSABLE OUTPUT:\*\*\s*(\{[\s\S]*?\})"
|
||||||
matches = re.search(pattern_asterisk, raw_response, re.DOTALL)
|
matches = re.search(pattern_asterisk, raw_response, re.DOTALL)
|
||||||
|
|
||||||
if not matches:
|
if not matches:
|
||||||
logger.debug(
|
logger.debug(f"[{self.model_name}] Regex parse #3 failed for {power_name}. Trying triple-backtick code fences.")
|
||||||
f"[{self.model_name}] Regex parse #3 failed for {power_name}. Trying triple-backtick code fences."
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2) If still no match, check for triple-backtick code fences containing JSON
|
# 2) If still no match, check for triple-backtick code fences containing JSON
|
||||||
if not matches:
|
if not matches:
|
||||||
code_fence_pattern = r"```json\n(.*?)\n```"
|
code_fence_pattern = r"```json\n(.*?)\n```"
|
||||||
matches = re.search(code_fence_pattern, raw_response, re.DOTALL)
|
matches = re.search(code_fence_pattern, raw_response, re.DOTALL)
|
||||||
if matches:
|
if matches:
|
||||||
logger.debug(
|
logger.debug(f"[{self.model_name}] Found triple-backtick JSON block for {power_name}.")
|
||||||
f"[{self.model_name}] Found triple-backtick JSON block for {power_name}."
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2b) Also try plain ``` code fences without json marker
|
# 2b) Also try plain ``` code fences without json marker
|
||||||
if not matches:
|
if not matches:
|
||||||
code_fence_plain = r"```\n(.*?)\n```"
|
code_fence_plain = r"```\n(.*?)\n```"
|
||||||
matches = re.search(code_fence_plain, raw_response, re.DOTALL)
|
matches = re.search(code_fence_plain, raw_response, re.DOTALL)
|
||||||
if matches:
|
if matches:
|
||||||
logger.debug(
|
logger.debug(f"[{self.model_name}] Found plain triple-backtick block for {power_name}.")
|
||||||
f"[{self.model_name}] Found plain triple-backtick block for {power_name}."
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2c) Try to find bare JSON object anywhere in the response
|
# 2c) Try to find bare JSON object anywhere in the response
|
||||||
if not matches:
|
if not matches:
|
||||||
logger.debug(
|
logger.debug(f"[{self.model_name}] No explicit markers found for {power_name}. Looking for bare JSON.")
|
||||||
f"[{self.model_name}] No explicit markers found for {power_name}. Looking for bare JSON."
|
|
||||||
)
|
|
||||||
# Look for a JSON object that contains "orders" key
|
# Look for a JSON object that contains "orders" key
|
||||||
bare_json_pattern = r'(\{[^{}]*"orders"\s*:\s*\[[^\]]*\][^{}]*\})'
|
bare_json_pattern = r'(\{[^{}]*"orders"\s*:\s*\[[^\]]*\][^{}]*\})'
|
||||||
matches = re.search(bare_json_pattern, raw_response, re.DOTALL)
|
matches = re.search(bare_json_pattern, raw_response, re.DOTALL)
|
||||||
if matches:
|
if matches:
|
||||||
logger.debug(
|
logger.debug(f"[{self.model_name}] Found bare JSON object with 'orders' key for {power_name}.")
|
||||||
f"[{self.model_name}] Found bare JSON object with 'orders' key for {power_name}."
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3) Attempt to parse JSON if we found anything
|
# 3) Attempt to parse JSON if we found anything
|
||||||
json_text = None
|
json_text = None
|
||||||
|
|
@ -283,9 +262,7 @@ class BaseModelClient:
|
||||||
json_text = json_text.strip()
|
json_text = json_text.strip()
|
||||||
|
|
||||||
if not json_text:
|
if not json_text:
|
||||||
logger.debug(
|
logger.debug(f"[{self.model_name}] No JSON text found in LLM response for {power_name}.")
|
||||||
f"[{self.model_name}] No JSON text found in LLM response for {power_name}."
|
|
||||||
)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 3a) Try JSON loading
|
# 3a) Try JSON loading
|
||||||
|
|
@ -293,14 +270,12 @@ class BaseModelClient:
|
||||||
data = json.loads(json_text)
|
data = json.loads(json_text)
|
||||||
return data.get("orders", None)
|
return data.get("orders", None)
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.warning(
|
logger.warning(f"[{self.model_name}] JSON decode failed for {power_name}: {e}. Trying to fix common issues.")
|
||||||
f"[{self.model_name}] JSON decode failed for {power_name}: {e}. Trying to fix common issues."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Try to fix common JSON issues
|
# Try to fix common JSON issues
|
||||||
try:
|
try:
|
||||||
# Remove trailing commas
|
# Remove trailing commas
|
||||||
fixed_json = re.sub(r',\s*([\}\]])', r'\1', json_text)
|
fixed_json = re.sub(r",\s*([\}\]])", r"\1", json_text)
|
||||||
# Fix single quotes to double quotes
|
# Fix single quotes to double quotes
|
||||||
fixed_json = fixed_json.replace("'", '"')
|
fixed_json = fixed_json.replace("'", '"')
|
||||||
# Try parsing again
|
# Try parsing again
|
||||||
|
|
@ -308,14 +283,12 @@ class BaseModelClient:
|
||||||
logger.info(f"[{self.model_name}] Successfully parsed JSON after fixes for {power_name}")
|
logger.info(f"[{self.model_name}] Successfully parsed JSON after fixes for {power_name}")
|
||||||
return data.get("orders", None)
|
return data.get("orders", None)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
logger.warning(
|
logger.warning(f"[{self.model_name}] JSON decode still failed after fixes for {power_name}. Trying to remove inline comments.")
|
||||||
f"[{self.model_name}] JSON decode still failed after fixes for {power_name}. Trying to remove inline comments."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Try to remove inline comments (// style)
|
# Try to remove inline comments (// style)
|
||||||
try:
|
try:
|
||||||
# Remove // comments from each line
|
# Remove // comments from each line
|
||||||
lines = json_text.split('\n')
|
lines = json_text.split("\n")
|
||||||
cleaned_lines = []
|
cleaned_lines = []
|
||||||
for line in lines:
|
for line in lines:
|
||||||
# Find // that's not inside quotes
|
# Find // that's not inside quotes
|
||||||
|
|
@ -326,12 +299,12 @@ class BaseModelClient:
|
||||||
if escape_next:
|
if escape_next:
|
||||||
escape_next = False
|
escape_next = False
|
||||||
continue
|
continue
|
||||||
if char == '\\':
|
if char == "\\":
|
||||||
escape_next = True
|
escape_next = True
|
||||||
continue
|
continue
|
||||||
if char == '"' and not escape_next:
|
if char == '"' and not escape_next:
|
||||||
in_quotes = not in_quotes
|
in_quotes = not in_quotes
|
||||||
if not in_quotes and line[i:i+2] == '//':
|
if not in_quotes and line[i : i + 2] == "//":
|
||||||
comment_pos = i
|
comment_pos = i
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
@ -342,17 +315,15 @@ class BaseModelClient:
|
||||||
cleaned_line = line
|
cleaned_line = line
|
||||||
cleaned_lines.append(cleaned_line)
|
cleaned_lines.append(cleaned_line)
|
||||||
|
|
||||||
comment_free_json = '\n'.join(cleaned_lines)
|
comment_free_json = "\n".join(cleaned_lines)
|
||||||
# Also remove trailing commas after comment removal
|
# Also remove trailing commas after comment removal
|
||||||
comment_free_json = re.sub(r',\s*([\}\]])', r'\1', comment_free_json)
|
comment_free_json = re.sub(r",\s*([\}\]])", r"\1", comment_free_json)
|
||||||
|
|
||||||
data = json.loads(comment_free_json)
|
data = json.loads(comment_free_json)
|
||||||
logger.info(f"[{self.model_name}] Successfully parsed JSON after removing inline comments for {power_name}")
|
logger.info(f"[{self.model_name}] Successfully parsed JSON after removing inline comments for {power_name}")
|
||||||
return data.get("orders", None)
|
return data.get("orders", None)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
logger.warning(
|
logger.warning(f"[{self.model_name}] JSON decode still failed after removing comments for {power_name}. Trying bracket fallback.")
|
||||||
f"[{self.model_name}] JSON decode still failed after removing comments for {power_name}. Trying bracket fallback."
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3b) Attempt bracket fallback: we look for the substring after "orders"
|
# 3b) Attempt bracket fallback: we look for the substring after "orders"
|
||||||
# E.g. "orders: ['A BUD H']" and parse it. This is risky but can help with minor JSON format errors.
|
# E.g. "orders: ['A BUD H']" and parse it. This is risky but can help with minor JSON format errors.
|
||||||
|
|
@ -366,23 +337,19 @@ class BaseModelClient:
|
||||||
if isinstance(moves, list):
|
if isinstance(moves, list):
|
||||||
return moves
|
return moves
|
||||||
except Exception as e2:
|
except Exception as e2:
|
||||||
logger.warning(
|
logger.warning(f"[{self.model_name}] Bracket fallback parse also failed for {power_name}: {e2}")
|
||||||
f"[{self.model_name}] Bracket fallback parse also failed for {power_name}: {e2}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# If all attempts failed
|
# If all attempts failed
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _validate_orders(
|
def _validate_orders(self, moves: List[str], possible_orders: Dict[str, List[str]]) -> Tuple[List[str], List[str]]: # MODIFIED RETURN TYPE
|
||||||
self, moves: List[str], possible_orders: Dict[str, List[str]]
|
|
||||||
) -> Tuple[List[str], List[str]]: # MODIFIED RETURN TYPE
|
|
||||||
"""
|
"""
|
||||||
Filter out invalid moves, fill missing with HOLD, else fallback.
|
Filter out invalid moves, fill missing with HOLD, else fallback.
|
||||||
Returns a tuple: (validated_moves, invalid_moves_found)
|
Returns a tuple: (validated_moves, invalid_moves_found)
|
||||||
"""
|
"""
|
||||||
logger.debug(f"[{self.model_name}] Proposed LLM moves: {moves}")
|
logger.debug(f"[{self.model_name}] Proposed LLM moves: {moves}")
|
||||||
validated = []
|
validated = []
|
||||||
invalid_moves_found = [] # ADDED: To collect invalid moves
|
invalid_moves_found = [] # ADDED: To collect invalid moves
|
||||||
used_locs = set()
|
used_locs = set()
|
||||||
|
|
||||||
if not isinstance(moves, list):
|
if not isinstance(moves, list):
|
||||||
|
|
@ -399,21 +366,21 @@ class BaseModelClient:
|
||||||
used_locs.add(parts[1][:3])
|
used_locs.add(parts[1][:3])
|
||||||
else:
|
else:
|
||||||
logger.debug(f"[{self.model_name}] Invalid move from LLM: {move_str}")
|
logger.debug(f"[{self.model_name}] Invalid move from LLM: {move_str}")
|
||||||
invalid_moves_found.append(move_str) # ADDED: Collect invalid move
|
invalid_moves_found.append(move_str) # ADDED: Collect invalid move
|
||||||
|
|
||||||
# Fill missing with hold
|
# Fill missing with hold
|
||||||
for loc, orders_list in possible_orders.items():
|
for loc, orders_list in possible_orders.items():
|
||||||
if loc not in used_locs and orders_list:
|
if loc not in used_locs and orders_list:
|
||||||
hold_candidates = [o for o in orders_list if o.endswith("H")]
|
hold_candidates = [o for o in orders_list if o.endswith("H")]
|
||||||
validated.append(
|
validated.append(hold_candidates[0] if hold_candidates else orders_list[0])
|
||||||
hold_candidates[0] if hold_candidates else orders_list[0]
|
|
||||||
)
|
|
||||||
|
|
||||||
if not validated and not invalid_moves_found: # Only if LLM provided no valid moves and no invalid moves (e.g. empty list from LLM)
|
if not validated and not invalid_moves_found: # Only if LLM provided no valid moves and no invalid moves (e.g. empty list from LLM)
|
||||||
logger.warning(f"[{self.model_name}] No valid LLM moves provided and no invalid ones to report. Using fallback.")
|
logger.warning(f"[{self.model_name}] No valid LLM moves provided and no invalid ones to report. Using fallback.")
|
||||||
return self.fallback_orders(possible_orders), []
|
return self.fallback_orders(possible_orders), []
|
||||||
elif not validated and invalid_moves_found: # All LLM moves were invalid
|
elif not validated and invalid_moves_found: # All LLM moves were invalid
|
||||||
logger.warning(f"[{self.model_name}] All LLM moves invalid ({len(invalid_moves_found)} found), using fallback. Invalid: {invalid_moves_found}")
|
logger.warning(
|
||||||
|
f"[{self.model_name}] All LLM moves invalid ({len(invalid_moves_found)} found), using fallback. Invalid: {invalid_moves_found}"
|
||||||
|
)
|
||||||
# We return empty list for validated, but the invalid_moves_found list is populated
|
# We return empty list for validated, but the invalid_moves_found list is populated
|
||||||
return self.fallback_orders(possible_orders), invalid_moves_found
|
return self.fallback_orders(possible_orders), invalid_moves_found
|
||||||
|
|
||||||
|
|
@ -442,9 +409,8 @@ class BaseModelClient:
|
||||||
# log_file_path: str, # Not used directly by build_context_prompt
|
# log_file_path: str, # Not used directly by build_context_prompt
|
||||||
agent_goals: Optional[List[str]] = None,
|
agent_goals: Optional[List[str]] = None,
|
||||||
agent_relationships: Optional[Dict[str, str]] = None,
|
agent_relationships: Optional[Dict[str, str]] = None,
|
||||||
agent_private_diary_str: Optional[str] = None, # Added
|
agent_private_diary_str: Optional[str] = None, # Added
|
||||||
) -> str:
|
) -> str:
|
||||||
|
|
||||||
instructions = load_prompt("planning_instructions.txt", prompts_dir=self.prompts_dir)
|
instructions = load_prompt("planning_instructions.txt", prompts_dir=self.prompts_dir)
|
||||||
|
|
||||||
context = self.build_context_prompt(
|
context = self.build_context_prompt(
|
||||||
|
|
@ -455,7 +421,7 @@ class BaseModelClient:
|
||||||
game_history,
|
game_history,
|
||||||
agent_goals=agent_goals,
|
agent_goals=agent_goals,
|
||||||
agent_relationships=agent_relationships,
|
agent_relationships=agent_relationships,
|
||||||
agent_private_diary=agent_private_diary_str, # Pass diary string
|
agent_private_diary=agent_private_diary_str, # Pass diary string
|
||||||
prompts_dir=self.prompts_dir,
|
prompts_dir=self.prompts_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -472,7 +438,7 @@ class BaseModelClient:
|
||||||
# log_file_path: str, # Not used directly by build_context_prompt
|
# log_file_path: str, # Not used directly by build_context_prompt
|
||||||
agent_goals: Optional[List[str]] = None,
|
agent_goals: Optional[List[str]] = None,
|
||||||
agent_relationships: Optional[Dict[str, str]] = None,
|
agent_relationships: Optional[Dict[str, str]] = None,
|
||||||
agent_private_diary_str: Optional[str] = None, # Added
|
agent_private_diary_str: Optional[str] = None, # Added
|
||||||
) -> str:
|
) -> str:
|
||||||
# MINIMAL CHANGE: Just change to load unformatted version conditionally
|
# MINIMAL CHANGE: Just change to load unformatted version conditionally
|
||||||
instructions = load_prompt(get_prompt_path("conversation_instructions.txt"), prompts_dir=self.prompts_dir)
|
instructions = load_prompt(get_prompt_path("conversation_instructions.txt"), prompts_dir=self.prompts_dir)
|
||||||
|
|
@ -486,7 +452,7 @@ class BaseModelClient:
|
||||||
game_history,
|
game_history,
|
||||||
agent_goals=agent_goals,
|
agent_goals=agent_goals,
|
||||||
agent_relationships=agent_relationships,
|
agent_relationships=agent_relationships,
|
||||||
agent_private_diary=agent_private_diary_str, # Pass diary string
|
agent_private_diary=agent_private_diary_str, # Pass diary string
|
||||||
prompts_dir=self.prompts_dir,
|
prompts_dir=self.prompts_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -497,7 +463,7 @@ class BaseModelClient:
|
||||||
logger.info(f"[{power_name}] Found {len(recent_messages_to_power)} high priority messages to respond to")
|
logger.info(f"[{power_name}] Found {len(recent_messages_to_power)} high priority messages to respond to")
|
||||||
if recent_messages_to_power:
|
if recent_messages_to_power:
|
||||||
for i, msg in enumerate(recent_messages_to_power):
|
for i, msg in enumerate(recent_messages_to_power):
|
||||||
logger.info(f"[{power_name}] Priority message {i+1}: From {msg['sender']} in {msg['phase']}: {msg['content'][:50]}...")
|
logger.info(f"[{power_name}] Priority message {i + 1}: From {msg['sender']} in {msg['phase']}: {msg['content'][:50]}...")
|
||||||
|
|
||||||
# KEEP ORIGINAL: Add a section for unanswered messages
|
# KEEP ORIGINAL: Add a section for unanswered messages
|
||||||
unanswered_messages = "\n\nRECENT MESSAGES REQUIRING YOUR ATTENTION:\n"
|
unanswered_messages = "\n\nRECENT MESSAGES REQUIRING YOUR ATTENTION:\n"
|
||||||
|
|
@ -508,23 +474,30 @@ class BaseModelClient:
|
||||||
unanswered_messages += "\nNo urgent messages requiring direct responses.\n"
|
unanswered_messages += "\nNo urgent messages requiring direct responses.\n"
|
||||||
|
|
||||||
final_prompt = context + unanswered_messages + "\n\n" + instructions
|
final_prompt = context + unanswered_messages + "\n\n" + instructions
|
||||||
final_prompt = final_prompt.replace('AUSTRIA', 'Austria').replace('ENGLAND', "England").replace('FRANCE', 'France').replace('GERMANY', 'Germany').replace('ITALY', "Italy").replace('RUSSIA', 'Russia').replace('TURKEY', 'Turkey')
|
final_prompt = (
|
||||||
|
final_prompt.replace("AUSTRIA", "Austria")
|
||||||
|
.replace("ENGLAND", "England")
|
||||||
|
.replace("FRANCE", "France")
|
||||||
|
.replace("GERMANY", "Germany")
|
||||||
|
.replace("ITALY", "Italy")
|
||||||
|
.replace("RUSSIA", "Russia")
|
||||||
|
.replace("TURKEY", "Turkey")
|
||||||
|
)
|
||||||
return final_prompt
|
return final_prompt
|
||||||
|
|
||||||
async def get_planning_reply( # Renamed from get_plan to avoid conflict with get_plan in agent.py
|
async def get_planning_reply( # Renamed from get_plan to avoid conflict with get_plan in agent.py
|
||||||
self,
|
self,
|
||||||
game,
|
game,
|
||||||
board_state,
|
board_state,
|
||||||
power_name: str,
|
power_name: str,
|
||||||
possible_orders: Dict[str, List[str]],
|
possible_orders: Dict[str, List[str]],
|
||||||
game_history: GameHistory,
|
game_history: GameHistory,
|
||||||
game_phase: str, # Used for logging
|
game_phase: str, # Used for logging
|
||||||
log_file_path: str, # Used for logging
|
log_file_path: str, # Used for logging
|
||||||
agent_goals: Optional[List[str]] = None,
|
agent_goals: Optional[List[str]] = None,
|
||||||
agent_relationships: Optional[Dict[str, str]] = None,
|
agent_relationships: Optional[Dict[str, str]] = None,
|
||||||
agent_private_diary_str: Optional[str] = None, # Added
|
agent_private_diary_str: Optional[str] = None, # Added
|
||||||
) -> str:
|
) -> str:
|
||||||
|
|
||||||
prompt = self.build_planning_prompt(
|
prompt = self.build_planning_prompt(
|
||||||
game,
|
game,
|
||||||
board_state,
|
board_state,
|
||||||
|
|
@ -535,7 +508,7 @@ class BaseModelClient:
|
||||||
# log_file_path, # Not passed to build_planning_prompt directly
|
# log_file_path, # Not passed to build_planning_prompt directly
|
||||||
agent_goals=agent_goals,
|
agent_goals=agent_goals,
|
||||||
agent_relationships=agent_relationships,
|
agent_relationships=agent_relationships,
|
||||||
agent_private_diary_str=agent_private_diary_str, # Pass diary string
|
agent_private_diary_str=agent_private_diary_str, # Pass diary string
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call LLM using the logging wrapper
|
# Call LLM using the logging wrapper
|
||||||
|
|
@ -544,8 +517,8 @@ class BaseModelClient:
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
log_file_path=log_file_path,
|
log_file_path=log_file_path,
|
||||||
power_name=power_name,
|
power_name=power_name,
|
||||||
phase=game_phase, # Use game_phase for logging
|
phase=game_phase, # Use game_phase for logging
|
||||||
response_type='plan_reply', # Changed from 'plan' to avoid confusion
|
response_type="plan_reply", # Changed from 'plan' to avoid confusion
|
||||||
)
|
)
|
||||||
logger.debug(f"[{self.model_name}] Raw LLM response for {power_name} planning reply:\n{raw_response}")
|
logger.debug(f"[{self.model_name}] Raw LLM response for {power_name} planning reply:\n{raw_response}")
|
||||||
return raw_response
|
return raw_response
|
||||||
|
|
@ -567,10 +540,10 @@ class BaseModelClient:
|
||||||
"""
|
"""
|
||||||
Generates a negotiation message, considering agent state.
|
Generates a negotiation message, considering agent state.
|
||||||
"""
|
"""
|
||||||
raw_input_prompt = "" # Initialize for finally block
|
raw_input_prompt = "" # Initialize for finally block
|
||||||
raw_response = "" # Initialize for finally block
|
raw_response = "" # Initialize for finally block
|
||||||
success_status = "Failure: Initialized" # Default status
|
success_status = "Failure: Initialized" # Default status
|
||||||
messages_to_return = [] # Initialize to ensure it's defined
|
messages_to_return = [] # Initialize to ensure it's defined
|
||||||
|
|
||||||
try:
|
try:
|
||||||
raw_input_prompt = self.build_conversation_prompt(
|
raw_input_prompt = self.build_conversation_prompt(
|
||||||
|
|
@ -592,7 +565,7 @@ class BaseModelClient:
|
||||||
log_file_path=log_file_path,
|
log_file_path=log_file_path,
|
||||||
power_name=power_name,
|
power_name=power_name,
|
||||||
phase=game_phase,
|
phase=game_phase,
|
||||||
response_type='negotiation', # For run_llm_and_log's internal context
|
response_type="negotiation", # For run_llm_and_log's internal context
|
||||||
)
|
)
|
||||||
logger.debug(f"[{self.model_name}] Raw LLM response for {power_name}:\n{raw_response}")
|
logger.debug(f"[{self.model_name}] Raw LLM response for {power_name}:\n{raw_response}")
|
||||||
|
|
||||||
|
|
@ -600,13 +573,10 @@ class BaseModelClient:
|
||||||
if config.USE_UNFORMATTED_PROMPTS:
|
if config.USE_UNFORMATTED_PROMPTS:
|
||||||
# Local import to avoid circular dependency
|
# Local import to avoid circular dependency
|
||||||
from .formatter import format_with_gemini_flash, FORMAT_CONVERSATION
|
from .formatter import format_with_gemini_flash, FORMAT_CONVERSATION
|
||||||
|
|
||||||
# Format the natural language response into structured JSON
|
# Format the natural language response into structured JSON
|
||||||
formatted_response = await format_with_gemini_flash(
|
formatted_response = await format_with_gemini_flash(
|
||||||
raw_response,
|
raw_response, FORMAT_CONVERSATION, power_name=power_name, phase=game_phase, log_file_path=log_file_path
|
||||||
FORMAT_CONVERSATION,
|
|
||||||
power_name=power_name,
|
|
||||||
phase=game_phase,
|
|
||||||
log_file_path=log_file_path
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Use the raw response directly (already formatted)
|
# Use the raw response directly (already formatted)
|
||||||
|
|
@ -632,10 +602,10 @@ class BaseModelClient:
|
||||||
# Original parsing logic as fallback
|
# Original parsing logic as fallback
|
||||||
if not parsed_messages:
|
if not parsed_messages:
|
||||||
# Attempt to find blocks enclosed in {{...}}
|
# Attempt to find blocks enclosed in {{...}}
|
||||||
double_brace_blocks = re.findall(r'\{\{(.*?)\}\}', raw_response, re.DOTALL)
|
double_brace_blocks = re.findall(r"\{\{(.*?)\}\}", raw_response, re.DOTALL)
|
||||||
if double_brace_blocks:
|
if double_brace_blocks:
|
||||||
# If {{...}} blocks are found, assume each is a self-contained JSON object
|
# If {{...}} blocks are found, assume each is a self-contained JSON object
|
||||||
json_blocks.extend(['{' + block.strip() + '}' for block in double_brace_blocks])
|
json_blocks.extend(["{" + block.strip() + "}" for block in double_brace_blocks])
|
||||||
else:
|
else:
|
||||||
# If no {{...}} blocks, look for ```json ... ``` markdown blocks
|
# If no {{...}} blocks, look for ```json ... ``` markdown blocks
|
||||||
code_block_match = re.search(r"```json\n(.*?)\n```", raw_response, re.DOTALL)
|
code_block_match = re.search(r"```json\n(.*?)\n```", raw_response, re.DOTALL)
|
||||||
|
|
@ -650,10 +620,10 @@ class BaseModelClient:
|
||||||
json_blocks = [json.dumps(data)]
|
json_blocks = [json.dumps(data)]
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
# If parsing the whole block fails, fall back to regex for individual objects
|
# If parsing the whole block fails, fall back to regex for individual objects
|
||||||
json_blocks = re.findall(r'\{.*?\}', potential_json_array_or_objects, re.DOTALL)
|
json_blocks = re.findall(r"\{.*?\}", potential_json_array_or_objects, re.DOTALL)
|
||||||
else:
|
else:
|
||||||
# If no markdown block, fall back to regex for any JSON object in the response
|
# If no markdown block, fall back to regex for any JSON object in the response
|
||||||
json_blocks = re.findall(r'\{.*?\}', raw_response, re.DOTALL)
|
json_blocks = re.findall(r"\{.*?\}", raw_response, re.DOTALL)
|
||||||
|
|
||||||
# Process json_blocks if we have them from fallback parsing
|
# Process json_blocks if we have them from fallback parsing
|
||||||
if not parsed_messages and json_blocks:
|
if not parsed_messages and json_blocks:
|
||||||
|
|
@ -661,7 +631,7 @@ class BaseModelClient:
|
||||||
try:
|
try:
|
||||||
cleaned_block = block.strip()
|
cleaned_block = block.strip()
|
||||||
# Attempt to fix common JSON issues like trailing commas before parsing
|
# Attempt to fix common JSON issues like trailing commas before parsing
|
||||||
cleaned_block = re.sub(r',\s*([\}\]])', r'\1', cleaned_block)
|
cleaned_block = re.sub(r",\s*([\}\]])", r"\1", cleaned_block)
|
||||||
parsed_message = json.loads(cleaned_block)
|
parsed_message = json.loads(cleaned_block)
|
||||||
parsed_messages.append(parsed_message)
|
parsed_messages.append(parsed_message)
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
|
|
@ -699,7 +669,7 @@ class BaseModelClient:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[{self.model_name}] Error in get_conversation_reply for {power_name}: {e}", exc_info=True)
|
logger.error(f"[{self.model_name}] Error in get_conversation_reply for {power_name}: {e}", exc_info=True)
|
||||||
success_status = f"Failure: Exception ({type(e).__name__})"
|
success_status = f"Failure: Exception ({type(e).__name__})"
|
||||||
messages_to_return = [] # Ensure empty list on general exception
|
messages_to_return = [] # Ensure empty list on general exception
|
||||||
finally:
|
finally:
|
||||||
if log_file_path:
|
if log_file_path:
|
||||||
log_llm_response(
|
log_llm_response(
|
||||||
|
|
@ -710,11 +680,11 @@ class BaseModelClient:
|
||||||
response_type="negotiation_message",
|
response_type="negotiation_message",
|
||||||
raw_input_prompt=raw_input_prompt,
|
raw_input_prompt=raw_input_prompt,
|
||||||
raw_response=raw_response,
|
raw_response=raw_response,
|
||||||
success=success_status
|
success=success_status,
|
||||||
)
|
)
|
||||||
return messages_to_return
|
return messages_to_return
|
||||||
|
|
||||||
async def get_plan( # This is the original get_plan, now distinct from get_planning_reply
|
async def get_plan( # This is the original get_plan, now distinct from get_planning_reply
|
||||||
self,
|
self,
|
||||||
game,
|
game,
|
||||||
board_state,
|
board_state,
|
||||||
|
|
@ -724,7 +694,7 @@ class BaseModelClient:
|
||||||
log_file_path: str,
|
log_file_path: str,
|
||||||
agent_goals: Optional[List[str]] = None,
|
agent_goals: Optional[List[str]] = None,
|
||||||
agent_relationships: Optional[Dict[str, str]] = None,
|
agent_relationships: Optional[Dict[str, str]] = None,
|
||||||
agent_private_diary_str: Optional[str] = None, # Added
|
agent_private_diary_str: Optional[str] = None, # Added
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Generates a strategic plan for the given power based on the current state.
|
Generates a strategic plan for the given power based on the current state.
|
||||||
|
|
@ -740,7 +710,7 @@ class BaseModelClient:
|
||||||
# For planning, possible_orders might be less critical for the context,
|
# For planning, possible_orders might be less critical for the context,
|
||||||
# but build_context_prompt expects it. We can pass an empty dict or calculate it.
|
# but build_context_prompt expects it. We can pass an empty dict or calculate it.
|
||||||
# For simplicity, let's pass empty if not strictly needed by context for planning.
|
# For simplicity, let's pass empty if not strictly needed by context for planning.
|
||||||
possible_orders_for_context = {} # game.get_all_possible_orders() if needed by context
|
possible_orders_for_context = {} # game.get_all_possible_orders() if needed by context
|
||||||
|
|
||||||
context_prompt = self.build_context_prompt(
|
context_prompt = self.build_context_prompt(
|
||||||
game,
|
game,
|
||||||
|
|
@ -750,7 +720,7 @@ class BaseModelClient:
|
||||||
game_history,
|
game_history,
|
||||||
agent_goals=agent_goals,
|
agent_goals=agent_goals,
|
||||||
agent_relationships=agent_relationships,
|
agent_relationships=agent_relationships,
|
||||||
agent_private_diary=agent_private_diary_str, # Pass diary string
|
agent_private_diary=agent_private_diary_str, # Pass diary string
|
||||||
prompts_dir=self.prompts_dir,
|
prompts_dir=self.prompts_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -765,12 +735,12 @@ class BaseModelClient:
|
||||||
try:
|
try:
|
||||||
# Use run_llm_and_log for the actual LLM call
|
# Use run_llm_and_log for the actual LLM call
|
||||||
raw_plan_response = await run_llm_and_log(
|
raw_plan_response = await run_llm_and_log(
|
||||||
client=self, # Pass self (the client instance)
|
client=self, # Pass self (the client instance)
|
||||||
prompt=full_prompt,
|
prompt=full_prompt,
|
||||||
log_file_path=log_file_path,
|
log_file_path=log_file_path,
|
||||||
power_name=power_name,
|
power_name=power_name,
|
||||||
phase=game.current_short_phase,
|
phase=game.current_short_phase,
|
||||||
response_type='plan_generation', # More specific type for run_llm_and_log context
|
response_type="plan_generation", # More specific type for run_llm_and_log context
|
||||||
)
|
)
|
||||||
logger.debug(f"[{self.model_name}] Raw LLM response for {power_name} plan generation:\n{raw_plan_response}")
|
logger.debug(f"[{self.model_name}] Raw LLM response for {power_name} plan generation:\n{raw_plan_response}")
|
||||||
# No parsing needed for the plan, return the raw string
|
# No parsing needed for the plan, return the raw string
|
||||||
|
|
@ -781,16 +751,16 @@ class BaseModelClient:
|
||||||
success_status = f"Failure: Exception ({type(e).__name__})"
|
success_status = f"Failure: Exception ({type(e).__name__})"
|
||||||
plan_to_return = f"Error: Failed to generate plan for {power_name} due to exception: {e}"
|
plan_to_return = f"Error: Failed to generate plan for {power_name} due to exception: {e}"
|
||||||
finally:
|
finally:
|
||||||
if log_file_path: # Only log if a path is provided
|
if log_file_path: # Only log if a path is provided
|
||||||
log_llm_response(
|
log_llm_response(
|
||||||
log_file_path=log_file_path,
|
log_file_path=log_file_path,
|
||||||
model_name=self.model_name,
|
model_name=self.model_name,
|
||||||
power_name=power_name,
|
power_name=power_name,
|
||||||
phase=game.current_short_phase if game else "UnknownPhase",
|
phase=game.current_short_phase if game else "UnknownPhase",
|
||||||
response_type="plan_generation", # Specific type for CSV logging
|
response_type="plan_generation", # Specific type for CSV logging
|
||||||
raw_input_prompt=full_prompt, # Renamed from 'full_prompt' to match log_llm_response arg
|
raw_input_prompt=full_prompt, # Renamed from 'full_prompt' to match log_llm_response arg
|
||||||
raw_response=raw_plan_response,
|
raw_response=raw_plan_response,
|
||||||
success=success_status
|
success=success_status,
|
||||||
# token_usage and cost can be added later
|
# token_usage and cost can be added later
|
||||||
)
|
)
|
||||||
return plan_to_return
|
return plan_to_return
|
||||||
|
|
@ -813,11 +783,7 @@ class OpenAIClient(BaseModelClient):
|
||||||
):
|
):
|
||||||
super().__init__(model_name, prompts_dir=prompts_dir)
|
super().__init__(model_name, prompts_dir=prompts_dir)
|
||||||
|
|
||||||
self.base_url = (
|
self.base_url = base_url or os.environ.get("OPENAI_BASE_URL") or "https://api.openai.com/v1"
|
||||||
base_url
|
|
||||||
or os.environ.get("OPENAI_BASE_URL")
|
|
||||||
or "https://api.openai.com/v1"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
|
self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
|
|
@ -832,18 +798,14 @@ class OpenAIClient(BaseModelClient):
|
||||||
inject_random_seed: bool = True,
|
inject_random_seed: bool = True,
|
||||||
) -> str:
|
) -> str:
|
||||||
try:
|
try:
|
||||||
system_prompt_content = (
|
system_prompt_content = f"{generate_random_seed()}\n\n{self.system_prompt}" if inject_random_seed else self.system_prompt
|
||||||
f"{generate_random_seed()}\n\n{self.system_prompt}"
|
|
||||||
if inject_random_seed
|
|
||||||
else self.system_prompt
|
|
||||||
)
|
|
||||||
prompt_with_cta = f"{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:"
|
prompt_with_cta = f"{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:"
|
||||||
|
|
||||||
response = await self.client.chat.completions.create(
|
response = await self.client.chat.completions.create(
|
||||||
model=self.model_name,
|
model=self.model_name,
|
||||||
messages=[
|
messages=[
|
||||||
{"role": "system", "content": system_prompt_content},
|
{"role": "system", "content": system_prompt_content},
|
||||||
{"role": "user", "content": prompt_with_cta},
|
{"role": "user", "content": prompt_with_cta},
|
||||||
],
|
],
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
max_tokens=self.max_tokens,
|
max_tokens=self.max_tokens,
|
||||||
|
|
@ -888,20 +850,14 @@ class ClaudeClient(BaseModelClient):
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
)
|
)
|
||||||
if not response.content:
|
if not response.content:
|
||||||
logger.warning(
|
logger.warning(f"[{self.model_name}] Empty content in Claude generate_response. Returning empty.")
|
||||||
f"[{self.model_name}] Empty content in Claude generate_response. Returning empty."
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
return response.content[0].text.strip() if response.content else ""
|
return response.content[0].text.strip() if response.content else ""
|
||||||
except json.JSONDecodeError as json_err:
|
except json.JSONDecodeError as json_err:
|
||||||
logger.error(
|
logger.error(f"[{self.model_name}] JSON decoding failed in generate_response: {json_err}")
|
||||||
f"[{self.model_name}] JSON decoding failed in generate_response: {json_err}"
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(f"[{self.model_name}] Unexpected error in generate_response: {e}")
|
||||||
f"[{self.model_name}] Unexpected error in generate_response: {e}"
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -929,19 +885,14 @@ class GeminiClient(BaseModelClient):
|
||||||
full_prompt = system_prompt_content + prompt + "\n\nPROVIDE YOUR RESPONSE BELOW:"
|
full_prompt = system_prompt_content + prompt + "\n\nPROVIDE YOUR RESPONSE BELOW:"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
generation_config = genai.types.GenerationConfig(
|
generation_config = genai.types.GenerationConfig(temperature=temperature, max_output_tokens=self.max_tokens)
|
||||||
temperature=temperature,
|
|
||||||
max_output_tokens=self.max_tokens
|
|
||||||
)
|
|
||||||
response = await self.client.generate_content_async(
|
response = await self.client.generate_content_async(
|
||||||
contents=full_prompt,
|
contents=full_prompt,
|
||||||
generation_config=generation_config,
|
generation_config=generation_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not response or not response.text:
|
if not response or not response.text:
|
||||||
logger.warning(
|
logger.warning(f"[{self.model_name}] Empty Gemini generate_response. Returning empty.")
|
||||||
f"[{self.model_name}] Empty Gemini generate_response. Returning empty."
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
return response.text.strip()
|
return response.text.strip()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -957,10 +908,7 @@ class DeepSeekClient(BaseModelClient):
|
||||||
def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
|
def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
|
||||||
super().__init__(model_name, prompts_dir=prompts_dir)
|
super().__init__(model_name, prompts_dir=prompts_dir)
|
||||||
self.api_key = os.environ.get("DEEPSEEK_API_KEY")
|
self.api_key = os.environ.get("DEEPSEEK_API_KEY")
|
||||||
self.client = AsyncDeepSeekOpenAI(
|
self.client = AsyncDeepSeekOpenAI(api_key=self.api_key, base_url="https://api.deepseek.com/")
|
||||||
api_key=self.api_key,
|
|
||||||
base_url="https://api.deepseek.com/"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def generate_response(self, prompt: str, temperature: float = 0.0, inject_random_seed: bool = True) -> str:
|
async def generate_response(self, prompt: str, temperature: float = 0.0, inject_random_seed: bool = True) -> str:
|
||||||
try:
|
try:
|
||||||
|
|
@ -986,9 +934,7 @@ class DeepSeekClient(BaseModelClient):
|
||||||
logger.debug(f"[{self.model_name}] Raw DeepSeek response:\n{response}")
|
logger.debug(f"[{self.model_name}] Raw DeepSeek response:\n{response}")
|
||||||
|
|
||||||
if not response or not response.choices:
|
if not response or not response.choices:
|
||||||
logger.warning(
|
logger.warning(f"[{self.model_name}] No valid response in generate_response.")
|
||||||
f"[{self.model_name}] No valid response in generate_response."
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
content = response.choices[0].message.content.strip()
|
content = response.choices[0].message.content.strip()
|
||||||
|
|
@ -998,9 +944,7 @@ class DeepSeekClient(BaseModelClient):
|
||||||
return content
|
return content
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(f"[{self.model_name}] Unexpected error in generate_response: {e}")
|
||||||
f"[{self.model_name}] Unexpected error in generate_response: {e}"
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1037,19 +981,14 @@ class OpenAIResponsesClient(BaseModelClient):
|
||||||
"max_tokens": self.max_tokens,
|
"max_tokens": self.max_tokens,
|
||||||
}
|
}
|
||||||
|
|
||||||
headers = {
|
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {self.api_key}"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Make the API call using aiohttp
|
# Make the API call using aiohttp
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.post(self.base_url, json=payload, headers=headers) as response:
|
async with session.post(self.base_url, json=payload, headers=headers) as response:
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
error_text = await response.text()
|
error_text = await response.text()
|
||||||
logger.error(
|
logger.error(f"[{self.model_name}] API error (status {response.status}): {error_text}")
|
||||||
f"[{self.model_name}] API error (status {response.status}): {error_text}"
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
response_data = await response.json()
|
response_data = await response.json()
|
||||||
|
|
@ -1059,24 +998,18 @@ class OpenAIResponsesClient(BaseModelClient):
|
||||||
try:
|
try:
|
||||||
outputs = response_data.get("output", [])
|
outputs = response_data.get("output", [])
|
||||||
if len(outputs) < 2:
|
if len(outputs) < 2:
|
||||||
logger.warning(
|
logger.warning(f"[{self.model_name}] Unexpected output structure. Full response: {response_data}")
|
||||||
f"[{self.model_name}] Unexpected output structure. Full response: {response_data}"
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# The message is typically in the second output item
|
# The message is typically in the second output item
|
||||||
message_output = outputs[1]
|
message_output = outputs[1]
|
||||||
if message_output.get("type") != "message":
|
if message_output.get("type") != "message":
|
||||||
logger.warning(
|
logger.warning(f"[{self.model_name}] Expected message type in output[1]. Got: {message_output.get('type')}")
|
||||||
f"[{self.model_name}] Expected message type in output[1]. Got: {message_output.get('type')}"
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
content_list = message_output.get("content", [])
|
content_list = message_output.get("content", [])
|
||||||
if not content_list:
|
if not content_list:
|
||||||
logger.warning(
|
logger.warning(f"[{self.model_name}] Empty content list in message output")
|
||||||
f"[{self.model_name}] Empty content list in message output"
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# Look for the content item with type 'output_text'
|
# Look for the content item with type 'output_text'
|
||||||
|
|
@ -1087,28 +1020,20 @@ class OpenAIResponsesClient(BaseModelClient):
|
||||||
break
|
break
|
||||||
|
|
||||||
if not text_content:
|
if not text_content:
|
||||||
logger.warning(
|
logger.warning(f"[{self.model_name}] No output_text found in content. Full content: {content_list}")
|
||||||
f"[{self.model_name}] No output_text found in content. Full content: {content_list}"
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
return text_content.strip()
|
return text_content.strip()
|
||||||
|
|
||||||
except (KeyError, IndexError, TypeError) as e:
|
except (KeyError, IndexError, TypeError) as e:
|
||||||
logger.error(
|
logger.error(f"[{self.model_name}] Error parsing response structure: {e}. Full response: {response_data}")
|
||||||
f"[{self.model_name}] Error parsing response structure: {e}. Full response: {response_data}"
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
except aiohttp.ClientError as e:
|
except aiohttp.ClientError as e:
|
||||||
logger.error(
|
logger.error(f"[{self.model_name}] HTTP client error in generate_response: {e}")
|
||||||
f"[{self.model_name}] HTTP client error in generate_response: {e}"
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(f"[{self.model_name}] Unexpected error in generate_response: {e}")
|
||||||
f"[{self.model_name}] Unexpected error in generate_response: {e}"
|
|
||||||
)
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1129,10 +1054,7 @@ class OpenRouterClient(BaseModelClient):
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError("OPENROUTER_API_KEY environment variable is required")
|
raise ValueError("OPENROUTER_API_KEY environment variable is required")
|
||||||
|
|
||||||
self.client = AsyncOpenAI(
|
self.client = AsyncOpenAI(base_url="https://openrouter.ai/api/v1", api_key=self.api_key)
|
||||||
base_url="https://openrouter.ai/api/v1",
|
|
||||||
api_key=self.api_key
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(f"[{self.model_name}] Initialized OpenRouter client")
|
logger.debug(f"[{self.model_name}] Initialized OpenRouter client")
|
||||||
|
|
||||||
|
|
@ -1150,10 +1072,7 @@ class OpenRouterClient(BaseModelClient):
|
||||||
# Prepare standard OpenAI-compatible request
|
# Prepare standard OpenAI-compatible request
|
||||||
response = await self.client.chat.completions.create(
|
response = await self.client.chat.completions.create(
|
||||||
model=self.model_name,
|
model=self.model_name,
|
||||||
messages=[
|
messages=[{"role": "system", "content": system_prompt_content}, {"role": "user", "content": prompt_with_cta}],
|
||||||
{"role": "system", "content": system_prompt_content},
|
|
||||||
{"role": "user", "content": prompt_with_cta}
|
|
||||||
],
|
|
||||||
max_tokens=self.max_tokens,
|
max_tokens=self.max_tokens,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
)
|
)
|
||||||
|
|
@ -1236,16 +1155,16 @@ class TogetherAIClient(BaseModelClient):
|
||||||
return ""
|
return ""
|
||||||
except TogetherAPIError as e:
|
except TogetherAPIError as e:
|
||||||
logger.error(f"[{self.model_name}] Together AI API error: {e}", exc_info=True)
|
logger.error(f"[{self.model_name}] Together AI API error: {e}", exc_info=True)
|
||||||
return f"Error: Together AI API error - {str(e)}" # Return a string with error info
|
return f"Error: Together AI API error - {str(e)}" # Return a string with error info
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[{self.model_name}] Unexpected error in TogetherAIClient: {e}", exc_info=True)
|
logger.error(f"[{self.model_name}] Unexpected error in TogetherAIClient: {e}", exc_info=True)
|
||||||
return f"Error: Unexpected error - {str(e)}" # Return a string with error info
|
return f"Error: Unexpected error - {str(e)}" # Return a string with error info
|
||||||
|
|
||||||
|
|
||||||
##############################################################################
|
##############################################################################
|
||||||
# RequestsOpenAIClient – sync requests, wrapped async (original + api_key)
|
# RequestsOpenAIClient – sync requests, wrapped async (original + api_key)
|
||||||
##############################################################################
|
##############################################################################
|
||||||
import requests, asyncio
|
|
||||||
|
|
||||||
class RequestsOpenAIClient(BaseModelClient):
|
class RequestsOpenAIClient(BaseModelClient):
|
||||||
"""
|
"""
|
||||||
|
|
@ -1266,18 +1185,14 @@ class RequestsOpenAIClient(BaseModelClient):
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError("OPENAI_API_KEY missing and no inline key provided")
|
raise ValueError("OPENAI_API_KEY missing and no inline key provided")
|
||||||
|
|
||||||
self.base_url = (
|
self.base_url = (base_url or os.environ.get("OPENAI_BASE_URL") or "https://api.openai.com/v1").rstrip("/")
|
||||||
base_url
|
|
||||||
or os.environ.get("OPENAI_BASE_URL")
|
|
||||||
or "https://api.openai.com/v1"
|
|
||||||
).rstrip("/")
|
|
||||||
|
|
||||||
self.endpoint = f"{self.base_url}/chat/completions"
|
self.endpoint = f"{self.base_url}/chat/completions"
|
||||||
|
|
||||||
# ---------------- internal blocking helper ---------------- #
|
# ---------------- internal blocking helper ---------------- #
|
||||||
def _post_sync(self, payload: dict) -> dict:
|
def _post_sync(self, payload: dict) -> dict:
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
}
|
}
|
||||||
r = requests.post(self.endpoint, headers=headers, json=payload, timeout=60)
|
r = requests.post(self.endpoint, headers=headers, json=payload, timeout=60)
|
||||||
|
|
@ -1291,17 +1206,13 @@ class RequestsOpenAIClient(BaseModelClient):
|
||||||
temperature: float = 0.0,
|
temperature: float = 0.0,
|
||||||
inject_random_seed: bool = True,
|
inject_random_seed: bool = True,
|
||||||
) -> str:
|
) -> str:
|
||||||
system_prompt_content = (
|
system_prompt_content = f"{generate_random_seed()}\n\n{self.system_prompt}" if inject_random_seed else self.system_prompt
|
||||||
f"{generate_random_seed()}\n\n{self.system_prompt}"
|
|
||||||
if inject_random_seed
|
|
||||||
else self.system_prompt
|
|
||||||
)
|
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"model": self.model_name,
|
"model": self.model_name,
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "system", "content": system_prompt_content},
|
{"role": "system", "content": system_prompt_content},
|
||||||
{"role": "user", "content": f"{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:"},
|
{"role": "user", "content": f"{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:"},
|
||||||
],
|
],
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
"max_tokens": self.max_tokens,
|
"max_tokens": self.max_tokens,
|
||||||
|
|
@ -1322,15 +1233,15 @@ class RequestsOpenAIClient(BaseModelClient):
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
##############################################################################
|
##############################################################################
|
||||||
# 3) Factory to Load Model Client
|
# 3) Factory to Load Model Client
|
||||||
##############################################################################
|
##############################################################################
|
||||||
class ModelSpec(NamedTuple):
|
class ModelSpec(NamedTuple):
|
||||||
prefix: Optional[str] # 'openai', 'requests', …
|
prefix: Optional[str] # 'openai', 'requests', …
|
||||||
model: str # 'gpt-4o'
|
model: str # 'gpt-4o'
|
||||||
base: Optional[str] # 'https://proxy.foo'
|
base: Optional[str] # 'https://proxy.foo'
|
||||||
key: Optional[str] # 'sk-…' (may be None)
|
key: Optional[str] # 'sk-…' (may be None)
|
||||||
|
|
||||||
|
|
||||||
def _parse_model_spec(raw: str) -> ModelSpec:
|
def _parse_model_spec(raw: str) -> ModelSpec:
|
||||||
"""
|
"""
|
||||||
|
|
@ -1339,11 +1250,11 @@ def _parse_model_spec(raw: str) -> ModelSpec:
|
||||||
"""
|
"""
|
||||||
raw = raw.strip()
|
raw = raw.strip()
|
||||||
|
|
||||||
pre_hash, _, key_part = raw.partition("#")
|
pre_hash, _, key_part = raw.partition("#")
|
||||||
pre_at, _, base_part = pre_hash.partition("@")
|
pre_at, _, base_part = pre_hash.partition("@")
|
||||||
|
|
||||||
maybe_pref, sep, model_part = pre_at.partition(":")
|
maybe_pref, sep, model_part = pre_at.partition(":")
|
||||||
if sep: # explicit prefix was present
|
if sep: # explicit prefix was present
|
||||||
prefix, model = maybe_pref.lower(), model_part
|
prefix, model = maybe_pref.lower(), model_part
|
||||||
else:
|
else:
|
||||||
prefix, model = None, maybe_pref
|
prefix, model = None, maybe_pref
|
||||||
|
|
@ -1379,17 +1290,17 @@ def load_model_client(model_id: str, prompts_dir: Optional[str] = None) -> BaseM
|
||||||
match spec.prefix:
|
match spec.prefix:
|
||||||
case "openai" | "oai":
|
case "openai" | "oai":
|
||||||
return OpenAIClient(
|
return OpenAIClient(
|
||||||
model_name = spec.model,
|
model_name=spec.model,
|
||||||
prompts_dir = prompts_dir,
|
prompts_dir=prompts_dir,
|
||||||
base_url = spec.base,
|
base_url=spec.base,
|
||||||
api_key = inline_key,
|
api_key=inline_key,
|
||||||
)
|
)
|
||||||
case "requests" | "req":
|
case "requests" | "req":
|
||||||
return RequestsOpenAIClient(
|
return RequestsOpenAIClient(
|
||||||
model_name = spec.model,
|
model_name=spec.model,
|
||||||
prompts_dir = prompts_dir,
|
prompts_dir=prompts_dir,
|
||||||
base_url = spec.base,
|
base_url=spec.base,
|
||||||
api_key = inline_key,
|
api_key=inline_key,
|
||||||
)
|
)
|
||||||
case "responses" | "oai-resp" | "openai-responses":
|
case "responses" | "oai-resp" | "openai-responses":
|
||||||
return OpenAIResponsesClient(spec.model, prompts_dir, api_key=inline_key)
|
return OpenAIResponsesClient(spec.model, prompts_dir, api_key=inline_key)
|
||||||
|
|
@ -1409,7 +1320,7 @@ def load_model_client(model_id: str, prompts_dir: Optional[str] = None) -> BaseM
|
||||||
# ------------------------------------------------------------------ #
|
# ------------------------------------------------------------------ #
|
||||||
# 2. Heuristic fallback path (identical to the original behaviour) #
|
# 2. Heuristic fallback path (identical to the original behaviour) #
|
||||||
# ------------------------------------------------------------------ #
|
# ------------------------------------------------------------------ #
|
||||||
lower_id = spec.model.lower()
|
lower_id = spec.model.lower()
|
||||||
|
|
||||||
if lower_id == "o3-pro":
|
if lower_id == "o3-pro":
|
||||||
return OpenAIResponsesClient(spec.model, prompts_dir, api_key=inline_key)
|
return OpenAIResponsesClient(spec.model, prompts_dir, api_key=inline_key)
|
||||||
|
|
@ -1432,17 +1343,15 @@ def load_model_client(model_id: str, prompts_dir: Optional[str] = None) -> BaseM
|
||||||
|
|
||||||
# Default: OpenAI-compatible async client
|
# Default: OpenAI-compatible async client
|
||||||
return OpenAIClient(
|
return OpenAIClient(
|
||||||
model_name = spec.model,
|
model_name=spec.model,
|
||||||
prompts_dir = prompts_dir,
|
prompts_dir=prompts_dir,
|
||||||
base_url = spec.base,
|
base_url=spec.base,
|
||||||
api_key = inline_key,
|
api_key=inline_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
return OpenAIClient(model_name, prompts_dir, base_url)
|
return OpenAIClient(model_name, prompts_dir, base_url)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
##############################################################################
|
##############################################################################
|
||||||
# 1) Add a method to filter visible messages (near top-level or in BaseModelClient)
|
# 1) Add a method to filter visible messages (near top-level or in BaseModelClient)
|
||||||
##############################################################################
|
##############################################################################
|
||||||
|
|
@ -1453,11 +1362,6 @@ def get_visible_messages_for_power(conversation_messages, power_name):
|
||||||
visible = []
|
visible = []
|
||||||
for msg in conversation_messages:
|
for msg in conversation_messages:
|
||||||
# GLOBAL might be 'ALL' or 'GLOBAL' depending on your usage
|
# GLOBAL might be 'ALL' or 'GLOBAL' depending on your usage
|
||||||
if (
|
if msg["recipient"] == "ALL" or msg["recipient"] == "GLOBAL" or msg["sender"] == power_name or msg["recipient"] == power_name:
|
||||||
msg["recipient"] == "ALL"
|
|
||||||
or msg["recipient"] == "GLOBAL"
|
|
||||||
or msg["sender"] == power_name
|
|
||||||
or msg["recipient"] == power_name
|
|
||||||
):
|
|
||||||
visible.append(msg)
|
visible.append(msg)
|
||||||
return visible # already in chronological order if appended that way
|
return visible # already in chronological order if appended that way
|
||||||
|
|
|
||||||
|
|
@ -11,25 +11,28 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> str | None:
|
def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> str | None:
|
||||||
"""A local copy of the helper from agent.py to avoid circular imports."""
|
"""A local copy of the helper from agent.py to avoid circular imports."""
|
||||||
import os
|
import os
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if prompts_dir:
|
if prompts_dir:
|
||||||
filepath = os.path.join(prompts_dir, filename)
|
filepath = os.path.join(prompts_dir, filename)
|
||||||
else:
|
else:
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
default_prompts_dir = os.path.join(current_dir, 'prompts')
|
default_prompts_dir = os.path.join(current_dir, "prompts")
|
||||||
filepath = os.path.join(default_prompts_dir, filename)
|
filepath = os.path.join(default_prompts_dir, filename)
|
||||||
|
|
||||||
with open(filepath, 'r', encoding='utf-8') as f:
|
with open(filepath, "r", encoding="utf-8") as f:
|
||||||
return f.read()
|
return f.read()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error loading prompt file {filepath}: {e}")
|
logger.error(f"Error loading prompt file {filepath}: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def run_diary_consolidation(
|
async def run_diary_consolidation(
|
||||||
agent: 'DiplomacyAgent',
|
agent: "DiplomacyAgent",
|
||||||
game: "Game",
|
game: "Game",
|
||||||
log_file_path: str,
|
log_file_path: str,
|
||||||
entries_to_keep_unsummarized: int = 6,
|
entries_to_keep_unsummarized: int = 6,
|
||||||
|
|
@ -39,70 +42,42 @@ async def run_diary_consolidation(
|
||||||
Consolidate older diary entries while keeping recent ones.
|
Consolidate older diary entries while keeping recent ones.
|
||||||
This is the logic moved from the DiplomacyAgent class.
|
This is the logic moved from the DiplomacyAgent class.
|
||||||
"""
|
"""
|
||||||
logger.info(
|
logger.info(f"[{agent.power_name}] CONSOLIDATION START — {len(agent.full_private_diary)} total full entries")
|
||||||
f"[{agent.power_name}] CONSOLIDATION START — "
|
|
||||||
f"{len(agent.full_private_diary)} total full entries"
|
|
||||||
)
|
|
||||||
|
|
||||||
full_entries = [
|
full_entries = [e for e in agent.full_private_diary if not e.startswith("[CONSOLIDATED HISTORY]")]
|
||||||
e for e in agent.full_private_diary
|
|
||||||
if not e.startswith("[CONSOLIDATED HISTORY]")
|
|
||||||
]
|
|
||||||
|
|
||||||
if len(full_entries) <= entries_to_keep_unsummarized:
|
if len(full_entries) <= entries_to_keep_unsummarized:
|
||||||
agent.private_diary = list(agent.full_private_diary)
|
agent.private_diary = list(agent.full_private_diary)
|
||||||
logger.info(
|
logger.info(f"[{agent.power_name}] ≤ {entries_to_keep_unsummarized} full entries — skipping consolidation")
|
||||||
f"[{agent.power_name}] ≤ {entries_to_keep_unsummarized} full entries — "
|
|
||||||
"skipping consolidation"
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
boundary_entry = full_entries[-entries_to_keep_unsummarized]
|
boundary_entry = full_entries[-entries_to_keep_unsummarized]
|
||||||
match = re.search(r"\[[SFWRAB]\s*(\d{4})", boundary_entry)
|
match = re.search(r"\[[SFWRAB]\s*(\d{4})", boundary_entry)
|
||||||
if not match:
|
if not match:
|
||||||
logger.error(
|
logger.error(f"[{agent.power_name}] Could not parse year from boundary entry; aborting consolidation")
|
||||||
f"[{agent.power_name}] Could not parse year from boundary entry; "
|
|
||||||
"aborting consolidation"
|
|
||||||
)
|
|
||||||
agent.private_diary = list(agent.full_private_diary)
|
agent.private_diary = list(agent.full_private_diary)
|
||||||
return
|
return
|
||||||
|
|
||||||
cutoff_year = int(match.group(1))
|
cutoff_year = int(match.group(1))
|
||||||
logger.info(
|
logger.info(f"[{agent.power_name}] Cut-off year for consolidation: {cutoff_year}")
|
||||||
f"[{agent.power_name}] Cut-off year for consolidation: {cutoff_year}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _entry_year(entry: str) -> int | None:
|
def _entry_year(entry: str) -> int | None:
|
||||||
m = re.search(r"\[[SFWRAB]\s*(\d{4})", entry)
|
m = re.search(r"\[[SFWRAB]\s*(\d{4})", entry)
|
||||||
return int(m.group(1)) if m else None
|
return int(m.group(1)) if m else None
|
||||||
|
|
||||||
entries_to_summarize = [
|
entries_to_summarize = [e for e in full_entries if (_entry_year(e) is not None and _entry_year(e) < cutoff_year)]
|
||||||
e for e in full_entries
|
entries_to_keep = [e for e in full_entries if (_entry_year(e) is None or _entry_year(e) >= cutoff_year)]
|
||||||
if (_entry_year(e) is not None and _entry_year(e) < cutoff_year)
|
|
||||||
]
|
|
||||||
entries_to_keep = [
|
|
||||||
e for e in full_entries
|
|
||||||
if (_entry_year(e) is None or _entry_year(e) >= cutoff_year)
|
|
||||||
]
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(f"[{agent.power_name}] Summarising {len(entries_to_summarize)} entries; keeping {len(entries_to_keep)} recent entries verbatim")
|
||||||
f"[{agent.power_name}] Summarising {len(entries_to_summarize)} entries; "
|
|
||||||
f"keeping {len(entries_to_keep)} recent entries verbatim"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not entries_to_summarize:
|
if not entries_to_summarize:
|
||||||
agent.private_diary = list(agent.full_private_diary)
|
agent.private_diary = list(agent.full_private_diary)
|
||||||
logger.warning(
|
logger.warning(f"[{agent.power_name}] No eligible entries to summarise; context diary left unchanged")
|
||||||
f"[{agent.power_name}] No eligible entries to summarise; "
|
|
||||||
"context diary left unchanged"
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
prompt_template = _load_prompt_file("diary_consolidation_prompt.txt", prompts_dir=prompts_dir)
|
prompt_template = _load_prompt_file("diary_consolidation_prompt.txt", prompts_dir=prompts_dir)
|
||||||
if not prompt_template:
|
if not prompt_template:
|
||||||
logger.error(
|
logger.error(f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting")
|
||||||
f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting"
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
prompt = prompt_template.format(
|
prompt = prompt_template.format(
|
||||||
|
|
@ -132,23 +107,14 @@ async def run_diary_consolidation(
|
||||||
new_summary_entry = f"[CONSOLIDATED HISTORY] {consolidated_text}"
|
new_summary_entry = f"[CONSOLIDATED HISTORY] {consolidated_text}"
|
||||||
agent.private_diary = [new_summary_entry] + entries_to_keep
|
agent.private_diary = [new_summary_entry] + entries_to_keep
|
||||||
success_flag = "TRUE"
|
success_flag = "TRUE"
|
||||||
logger.info(
|
logger.info(f"[{agent.power_name}] Consolidation complete — {len(agent.private_diary)} context entries now")
|
||||||
f"[{agent.power_name}] Consolidation complete — "
|
|
||||||
f"{len(agent.private_diary)} context entries now"
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error(
|
logger.error(f"[{agent.power_name}] Diary consolidation failed: {exc}", exc_info=True)
|
||||||
f"[{agent.power_name}] Diary consolidation failed: {exc}", exc_info=True
|
|
||||||
)
|
|
||||||
finally:
|
finally:
|
||||||
log_llm_response(
|
log_llm_response(
|
||||||
log_file_path=log_file_path,
|
log_file_path=log_file_path,
|
||||||
model_name=(
|
model_name=(consolidation_client.model_name if consolidation_client is not None else agent.client.model_name),
|
||||||
consolidation_client.model_name
|
|
||||||
if consolidation_client is not None
|
|
||||||
else agent.client.model_name
|
|
||||||
),
|
|
||||||
power_name=agent.power_name,
|
power_name=agent.power_name,
|
||||||
phase=game.current_short_phase,
|
phase=game.current_short_phase,
|
||||||
response_type="diary_consolidation",
|
response_type="diary_consolidation",
|
||||||
|
|
|
||||||
|
|
@ -3,14 +3,13 @@ Formatter module for converting natural language LLM responses to structured JSO
|
||||||
Uses Gemini 2.5 Flash via OpenRouter to extract and format information from reasoning-focused responses.
|
Uses Gemini 2.5 Flash via OpenRouter to extract and format information from reasoning-focused responses.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
from typing import Optional
|
||||||
from typing import Dict, Any, Optional
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
# Import logging function and model configuration
|
# Import logging function and model configuration
|
||||||
from .utils import log_llm_response, get_special_models
|
from .utils import log_llm_response, get_special_models
|
||||||
|
|
||||||
# Import client loading function
|
# Import client loading function
|
||||||
from .clients import load_model_client
|
from .clients import load_model_client
|
||||||
|
|
||||||
|
|
@ -26,11 +25,7 @@ FORMAT_ORDER_DIARY = "order_diary"
|
||||||
|
|
||||||
|
|
||||||
async def format_with_gemini_flash(
|
async def format_with_gemini_flash(
|
||||||
raw_response: str,
|
raw_response: str, format_type: str, power_name: Optional[str] = None, phase: Optional[str] = None, log_file_path: Optional[str] = None
|
||||||
format_type: str,
|
|
||||||
power_name: Optional[str] = None,
|
|
||||||
phase: Optional[str] = None,
|
|
||||||
log_file_path: Optional[str] = None
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Format a natural language response into required JSON structure using Gemini Flash.
|
Format a natural language response into required JSON structure using Gemini Flash.
|
||||||
|
|
@ -52,7 +47,7 @@ async def format_with_gemini_flash(
|
||||||
FORMAT_NEGOTIATION_DIARY: "formatting/format_negotiation_diary.txt",
|
FORMAT_NEGOTIATION_DIARY: "formatting/format_negotiation_diary.txt",
|
||||||
FORMAT_ORDERS: "formatting/format_orders.txt",
|
FORMAT_ORDERS: "formatting/format_orders.txt",
|
||||||
FORMAT_INITIAL_STATE: "formatting/format_initial_state.txt",
|
FORMAT_INITIAL_STATE: "formatting/format_initial_state.txt",
|
||||||
FORMAT_ORDER_DIARY: "formatting/format_order_diary.txt"
|
FORMAT_ORDER_DIARY: "formatting/format_order_diary.txt",
|
||||||
}
|
}
|
||||||
|
|
||||||
if format_type not in format_prompts:
|
if format_type not in format_prompts:
|
||||||
|
|
@ -63,7 +58,7 @@ async def format_with_gemini_flash(
|
||||||
if not prompt_file.exists():
|
if not prompt_file.exists():
|
||||||
raise FileNotFoundError(f"Formatting prompt not found: {prompt_file}")
|
raise FileNotFoundError(f"Formatting prompt not found: {prompt_file}")
|
||||||
|
|
||||||
with open(prompt_file, 'r') as f:
|
with open(prompt_file, "r") as f:
|
||||||
format_prompt = f.read()
|
format_prompt = f.read()
|
||||||
|
|
||||||
# Replace placeholder with actual response
|
# Replace placeholder with actual response
|
||||||
|
|
@ -87,7 +82,7 @@ async def format_with_gemini_flash(
|
||||||
formatted_response = await formatter_client.generate_response(
|
formatted_response = await formatter_client.generate_response(
|
||||||
prompt=format_prompt,
|
prompt=format_prompt,
|
||||||
temperature=0, # Deterministic formatting
|
temperature=0, # Deterministic formatting
|
||||||
inject_random_seed=False # No need for random seed in formatting
|
inject_random_seed=False, # No need for random seed in formatting
|
||||||
)
|
)
|
||||||
|
|
||||||
if not formatted_response:
|
if not formatted_response:
|
||||||
|
|
@ -116,7 +111,7 @@ async def format_with_gemini_flash(
|
||||||
response_type=f"format_{format_type}",
|
response_type=f"format_{format_type}",
|
||||||
raw_input_prompt=format_prompt,
|
raw_input_prompt=format_prompt,
|
||||||
raw_response=formatted_response,
|
raw_response=formatted_response,
|
||||||
success="Success"
|
success="Success",
|
||||||
)
|
)
|
||||||
|
|
||||||
return formatted_response
|
return formatted_response
|
||||||
|
|
@ -134,7 +129,7 @@ async def format_with_gemini_flash(
|
||||||
response_type=f"format_{format_type}",
|
response_type=f"format_{format_type}",
|
||||||
raw_input_prompt=format_prompt,
|
raw_input_prompt=format_prompt,
|
||||||
raw_response=f"ERROR: {str(e)}",
|
raw_response=f"ERROR: {str(e)}",
|
||||||
success=f"Failure: {type(e).__name__}"
|
success=f"Failure: {type(e).__name__}",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Return empty structure based on format type
|
# Return empty structure based on format type
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@ import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
from pprint import pformat
|
|
||||||
|
|
||||||
logger = logging.getLogger("utils")
|
logger = logging.getLogger("utils")
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
@ -24,12 +23,8 @@ class Phase:
|
||||||
plans: Dict[str, str] = field(default_factory=dict)
|
plans: Dict[str, str] = field(default_factory=dict)
|
||||||
messages: List[Message] = field(default_factory=list)
|
messages: List[Message] = field(default_factory=list)
|
||||||
submitted_orders_by_power: Dict[str, List[str]] = field(default_factory=dict)
|
submitted_orders_by_power: Dict[str, List[str]] = field(default_factory=dict)
|
||||||
orders_by_power: Dict[str, List[str]] = field(
|
orders_by_power: Dict[str, List[str]] = field(default_factory=lambda: defaultdict(list))
|
||||||
default_factory=lambda: defaultdict(list)
|
results_by_power: Dict[str, List[List[str]]] = field(default_factory=lambda: defaultdict(list))
|
||||||
)
|
|
||||||
results_by_power: Dict[str, List[List[str]]] = field(
|
|
||||||
default_factory=lambda: defaultdict(list)
|
|
||||||
)
|
|
||||||
# NEW: Store phase-end summaries provided by each power
|
# NEW: Store phase-end summaries provided by each power
|
||||||
phase_summaries: Dict[str, str] = field(default_factory=dict)
|
phase_summaries: Dict[str, str] = field(default_factory=dict)
|
||||||
# NEW: Store experience/journal updates from each power for this phase
|
# NEW: Store experience/journal updates from each power for this phase
|
||||||
|
|
@ -39,9 +34,7 @@ class Phase:
|
||||||
self.plans[power_name] = plan
|
self.plans[power_name] = plan
|
||||||
|
|
||||||
def add_message(self, sender: str, recipient: str, content: str):
|
def add_message(self, sender: str, recipient: str, content: str):
|
||||||
self.messages.append(
|
self.messages.append(Message(sender=sender, recipient=recipient, content=content))
|
||||||
Message(sender=sender, recipient=recipient, content=content)
|
|
||||||
)
|
|
||||||
|
|
||||||
def add_orders(self, power: str, orders: List[str], results: List[List[str]]):
|
def add_orders(self, power: str, orders: List[str], results: List[List[str]]):
|
||||||
self.orders_by_power[power].extend(orders)
|
self.orders_by_power[power].extend(orders)
|
||||||
|
|
@ -110,14 +103,10 @@ class GameHistory:
|
||||||
phase.plans[power_name] = plan
|
phase.plans[power_name] = plan
|
||||||
logger.debug(f"Added plan for {power_name} in {phase_name}")
|
logger.debug(f"Added plan for {power_name} in {phase_name}")
|
||||||
|
|
||||||
def add_message(
|
def add_message(self, phase_name: str, sender: str, recipient: str, message_content: str):
|
||||||
self, phase_name: str, sender: str, recipient: str, message_content: str
|
|
||||||
):
|
|
||||||
phase = self._get_phase(phase_name)
|
phase = self._get_phase(phase_name)
|
||||||
if phase:
|
if phase:
|
||||||
message = Message(
|
message = Message(sender=sender, recipient=recipient, content=message_content)
|
||||||
sender=sender, recipient=recipient, content=message_content
|
|
||||||
)
|
|
||||||
phase.messages.append(message)
|
phase.messages.append(message)
|
||||||
logger.debug(f"Added message from {sender} to {recipient} in {phase_name}")
|
logger.debug(f"Added message from {sender} to {recipient} in {phase_name}")
|
||||||
|
|
||||||
|
|
@ -153,10 +142,6 @@ class GameHistory:
|
||||||
return {}
|
return {}
|
||||||
return self.phases[-1].plans
|
return self.phases[-1].plans
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_order_history_for_prompt(
|
def get_order_history_for_prompt(
|
||||||
self,
|
self,
|
||||||
game: "Game",
|
game: "Game",
|
||||||
|
|
@ -194,10 +179,7 @@ class GameHistory:
|
||||||
return str(tag).strip().lower()
|
return str(tag).strip().lower()
|
||||||
|
|
||||||
engine_phases = {ph.name: ph for ph in getattr(game, "get_phase_history", lambda: [])()}
|
engine_phases = {ph.name: ph for ph in getattr(game, "get_phase_history", lambda: [])()}
|
||||||
eng2code = {
|
eng2code = {"AUSTRIA": "AUT", "ENGLAND": "ENG", "FRANCE": "FRA", "GERMANY": "GER", "ITALY": "ITA", "RUSSIA": "RUS", "TURKEY": "TUR"}
|
||||||
"AUSTRIA": "AUT", "ENGLAND": "ENG", "FRANCE": "FRA",
|
|
||||||
"GERMANY": "GER", "ITALY": "ITA", "RUSSIA": "RUS", "TURKEY": "TUR"
|
|
||||||
}
|
|
||||||
norm = game.map.norm
|
norm = game.map.norm
|
||||||
|
|
||||||
out_lines = ["**ORDER HISTORY (Recent Rounds)**"]
|
out_lines = ["**ORDER HISTORY (Recent Rounds)**"]
|
||||||
|
|
@ -209,14 +191,14 @@ class GameHistory:
|
||||||
|
|
||||||
for pwr in sorted(set(ph.orders_by_power) | set(ph.submitted_orders_by_power)):
|
for pwr in sorted(set(ph.orders_by_power) | set(ph.submitted_orders_by_power)):
|
||||||
submitted = ph.submitted_orders_by_power.get(pwr, [])
|
submitted = ph.submitted_orders_by_power.get(pwr, [])
|
||||||
accepted = ph.orders_by_power.get(pwr, [])
|
accepted = ph.orders_by_power.get(pwr, [])
|
||||||
|
|
||||||
if isinstance(submitted, str):
|
if isinstance(submitted, str):
|
||||||
submitted = [submitted]
|
submitted = [submitted]
|
||||||
if isinstance(accepted, str):
|
if isinstance(accepted, str):
|
||||||
accepted = [accepted]
|
accepted = [accepted]
|
||||||
|
|
||||||
def _norm_keep(o): # keep WAIVE readable
|
def _norm_keep(o): # keep WAIVE readable
|
||||||
return o if o.upper() == "WAIVE" else norm(o)
|
return o if o.upper() == "WAIVE" else norm(o)
|
||||||
|
|
||||||
sub_norm = {_norm_keep(o): o for o in submitted}
|
sub_norm = {_norm_keep(o): o for o in submitted}
|
||||||
|
|
@ -264,13 +246,6 @@ class GameHistory:
|
||||||
return "\n(No orders were issued in recent history)\n"
|
return "\n(No orders were issued in recent history)\n"
|
||||||
return "\n".join(out_lines)
|
return "\n".join(out_lines)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_messages_this_round(self, power_name: str, current_phase_name: str) -> str:
|
def get_messages_this_round(self, power_name: str, current_phase_name: str) -> str:
|
||||||
current_phase: Optional[Phase] = None
|
current_phase: Optional[Phase] = None
|
||||||
for phase_obj in self.phases:
|
for phase_obj in self.phases:
|
||||||
|
|
@ -325,11 +300,7 @@ class GameHistory:
|
||||||
if msg.recipient == power_name or (msg.recipient == "GLOBAL" and msg.sender != power_name):
|
if msg.recipient == power_name or (msg.recipient == "GLOBAL" and msg.sender != power_name):
|
||||||
# Skip if sender is this power (don't need to respond to own messages)
|
# Skip if sender is this power (don't need to respond to own messages)
|
||||||
if msg.sender != power_name:
|
if msg.sender != power_name:
|
||||||
messages_to_power.append({
|
messages_to_power.append({"sender": msg.sender, "content": msg.content, "phase": phase.name})
|
||||||
'sender': msg.sender,
|
|
||||||
'content': msg.content,
|
|
||||||
'phase': phase.name
|
|
||||||
})
|
|
||||||
|
|
||||||
# Add debug logging
|
# Add debug logging
|
||||||
logger.info(f"Found {len(messages_to_power)} messages to {power_name} across {len(recent_phases)} phases")
|
logger.info(f"Found {len(messages_to_power)} messages to {power_name} across {len(recent_phases)} phases")
|
||||||
|
|
@ -361,10 +332,10 @@ class GameHistory:
|
||||||
for msg in phase.messages:
|
for msg in phase.messages:
|
||||||
# Handle both Message objects and dict objects
|
# Handle both Message objects and dict objects
|
||||||
if isinstance(msg, Message):
|
if isinstance(msg, Message):
|
||||||
if msg.sender == sender_name and msg.recipient not in ['GLOBAL', 'ALL']:
|
if msg.sender == sender_name and msg.recipient not in ["GLOBAL", "ALL"]:
|
||||||
sender_messages.append(msg)
|
sender_messages.append(msg)
|
||||||
else: # Assume dict
|
else: # Assume dict
|
||||||
if msg['sender'] == sender_name and msg['recipient'] not in ['GLOBAL', 'ALL']:
|
if msg["sender"] == sender_name and msg["recipient"] not in ["GLOBAL", "ALL"]:
|
||||||
sender_messages.append(msg)
|
sender_messages.append(msg)
|
||||||
|
|
||||||
# Check for responses in this and next phases
|
# Check for responses in this and next phases
|
||||||
|
|
@ -374,25 +345,27 @@ class GameHistory:
|
||||||
recipient = msg.recipient
|
recipient = msg.recipient
|
||||||
msg_content = msg.content
|
msg_content = msg.content
|
||||||
else:
|
else:
|
||||||
recipient = msg['recipient']
|
recipient = msg["recipient"]
|
||||||
msg_content = msg['content']
|
msg_content = msg["content"]
|
||||||
|
|
||||||
# Look for responses in current phase and next phases
|
# Look for responses in current phase and next phases
|
||||||
found_response = False
|
found_response = False
|
||||||
|
|
||||||
# Check remaining phases starting from current
|
# Check remaining phases starting from current
|
||||||
for check_phase in recent_phases[i:min(i+2, len(recent_phases))]:
|
for check_phase in recent_phases[i : min(i + 2, len(recent_phases))]:
|
||||||
# Look for messages FROM the recipient TO the sender (direct response)
|
# Look for messages FROM the recipient TO the sender (direct response)
|
||||||
# or FROM the recipient to GLOBAL/ALL that might acknowledge sender
|
# or FROM the recipient to GLOBAL/ALL that might acknowledge sender
|
||||||
response_msgs = []
|
response_msgs = []
|
||||||
for m in check_phase.messages:
|
for m in check_phase.messages:
|
||||||
if isinstance(m, Message):
|
if isinstance(m, Message):
|
||||||
if m.sender == recipient and (m.recipient == sender_name or
|
if m.sender == recipient and (
|
||||||
(m.recipient in ['GLOBAL', 'ALL'] and sender_name in m.content)):
|
m.recipient == sender_name or (m.recipient in ["GLOBAL", "ALL"] and sender_name in m.content)
|
||||||
|
):
|
||||||
response_msgs.append(m)
|
response_msgs.append(m)
|
||||||
else: # Assume dict
|
else: # Assume dict
|
||||||
if m['sender'] == recipient and (m['recipient'] == sender_name or
|
if m["sender"] == recipient and (
|
||||||
(m['recipient'] in ['GLOBAL', 'ALL'] and sender_name in m.get('content', ''))):
|
m["recipient"] == sender_name or (m["recipient"] in ["GLOBAL", "ALL"] and sender_name in m.get("content", ""))
|
||||||
|
):
|
||||||
response_msgs.append(m)
|
response_msgs.append(m)
|
||||||
|
|
||||||
if response_msgs:
|
if response_msgs:
|
||||||
|
|
@ -402,9 +375,6 @@ class GameHistory:
|
||||||
if not found_response:
|
if not found_response:
|
||||||
if recipient not in ignored_by_power:
|
if recipient not in ignored_by_power:
|
||||||
ignored_by_power[recipient] = []
|
ignored_by_power[recipient] = []
|
||||||
ignored_by_power[recipient].append({
|
ignored_by_power[recipient].append({"phase": phase.name, "content": msg_content})
|
||||||
'phase': phase.name,
|
|
||||||
'content': msg_content
|
|
||||||
})
|
|
||||||
|
|
||||||
return ignored_by_power
|
return ignored_by_power
|
||||||
|
|
@ -3,7 +3,7 @@ import logging
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Dict, List, Tuple, Optional, Any
|
from typing import Dict, Tuple, Optional, Any
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
|
||||||
from diplomacy import Game
|
from diplomacy import Game
|
||||||
|
|
@ -19,6 +19,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# --- Serialization / Deserialization ---
|
# --- Serialization / Deserialization ---
|
||||||
|
|
||||||
|
|
||||||
def serialize_agent(agent: DiplomacyAgent) -> dict:
|
def serialize_agent(agent: DiplomacyAgent) -> dict:
|
||||||
"""Converts an agent object to a JSON-serializable dictionary."""
|
"""Converts an agent object to a JSON-serializable dictionary."""
|
||||||
return {
|
return {
|
||||||
|
|
@ -31,12 +32,8 @@ def serialize_agent(agent: DiplomacyAgent) -> dict:
|
||||||
"private_diary": agent.private_diary,
|
"private_diary": agent.private_diary,
|
||||||
}
|
}
|
||||||
|
|
||||||
def deserialize_agent(
|
|
||||||
agent_data: dict,
|
def deserialize_agent(agent_data: dict, prompts_dir: Optional[str] = None, *, override_model_id: Optional[str] = None) -> DiplomacyAgent:
|
||||||
prompts_dir: Optional[str] = None,
|
|
||||||
*,
|
|
||||||
override_model_id: Optional[str] = None
|
|
||||||
) -> DiplomacyAgent:
|
|
||||||
"""
|
"""
|
||||||
Recreates an agent object from a dictionary.
|
Recreates an agent object from a dictionary.
|
||||||
|
|
||||||
|
|
@ -68,34 +65,31 @@ def deserialize_agent(
|
||||||
# --- State Management ---
|
# --- State Management ---
|
||||||
|
|
||||||
# game_logic.py
|
# game_logic.py
|
||||||
_PHASE_ORDER = ["M", "R", "A"] # Movement → Retreats → Adjustments
|
_PHASE_ORDER = ["M", "R", "A"] # Movement → Retreats → Adjustments
|
||||||
|
|
||||||
|
|
||||||
def _next_phase_name(short: str) -> str:
|
def _next_phase_name(short: str) -> str:
|
||||||
"""
|
"""
|
||||||
Return the Diplomacy phase string that chronologically follows *short*.
|
Return the Diplomacy phase string that chronologically follows *short*.
|
||||||
(E.g. S1901M → S1901R, S1901R → W1901A, W1901A → S1902M)
|
(E.g. S1901M → S1901R, S1901R → W1901A, W1901A → S1902M)
|
||||||
"""
|
"""
|
||||||
season = short[0] # 'S' | 'W'
|
season = short[0] # 'S' | 'W'
|
||||||
year = int(short[1:5])
|
year = int(short[1:5])
|
||||||
typ = short[-1] # 'M' | 'R' | 'A'
|
typ = short[-1] # 'M' | 'R' | 'A'
|
||||||
|
|
||||||
idx = _PHASE_ORDER.index(typ)
|
idx = _PHASE_ORDER.index(typ)
|
||||||
if idx < 2: # still in the same season
|
if idx < 2: # still in the same season
|
||||||
return f"{season}{year}{_PHASE_ORDER[idx+1]}"
|
return f"{season}{year}{_PHASE_ORDER[idx + 1]}"
|
||||||
|
|
||||||
# typ was 'A' → roll season
|
# typ was 'A' → roll season
|
||||||
if season == "S": # summer → winter, same year
|
if season == "S": # summer → winter, same year
|
||||||
return f"W{year}M"
|
return f"W{year}M"
|
||||||
else: # winter→ spring, next year
|
else: # winter→ spring, next year
|
||||||
return f"S{year+1}M"
|
return f"S{year + 1}M"
|
||||||
|
|
||||||
|
|
||||||
def save_game_state(
|
def save_game_state(
|
||||||
game: Game,
|
game: Game, agents: Dict[str, DiplomacyAgent], game_history: GameHistory, output_path: str, run_config: Namespace, completed_phase_name: str
|
||||||
agents: Dict[str, DiplomacyAgent],
|
|
||||||
game_history: GameHistory,
|
|
||||||
output_path: str,
|
|
||||||
run_config: Namespace,
|
|
||||||
completed_phase_name: str
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Serialise the entire game to JSON, preserving per-phase custom metadata
|
Serialise the entire game to JSON, preserving per-phase custom metadata
|
||||||
|
|
@ -130,9 +124,7 @@ def save_game_state(
|
||||||
if extras:
|
if extras:
|
||||||
previous_phase_extras[phase["name"]] = extras
|
previous_phase_extras[phase["name"]] = extras
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning(
|
logger.warning("Could not load previous save to retain metadata: %s", exc, exc_info=True)
|
||||||
"Could not load previous save to retain metadata: %s", exc, exc_info=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# -------------------------------------------------------------- #
|
# -------------------------------------------------------------- #
|
||||||
# 2. Build the fresh base structure from the diplomacy library. #
|
# 2. Build the fresh base structure from the diplomacy library. #
|
||||||
|
|
@ -143,11 +135,7 @@ def save_game_state(
|
||||||
# 3. Walk every phase and merge the metadata back in. #
|
# 3. Walk every phase and merge the metadata back in. #
|
||||||
# -------------------------------------------------------------- #
|
# -------------------------------------------------------------- #
|
||||||
# Capture the *current* snapshot of every live agent exactly once.
|
# Capture the *current* snapshot of every live agent exactly once.
|
||||||
current_state_agents = {
|
current_state_agents = {p_name: serialize_agent(p_agent) for p_name, p_agent in agents.items() if not game.powers[p_name].is_eliminated()}
|
||||||
p_name: serialize_agent(p_agent)
|
|
||||||
for p_name, p_agent in agents.items()
|
|
||||||
if not game.powers[p_name].is_eliminated()
|
|
||||||
}
|
|
||||||
|
|
||||||
for phase_block in saved_game.get("phases", []):
|
for phase_block in saved_game.get("phases", []):
|
||||||
if int(phase_block["name"][1:5]) > run_config.max_year:
|
if int(phase_block["name"][1:5]) > run_config.max_year:
|
||||||
|
|
@ -167,31 +155,24 @@ def save_game_state(
|
||||||
|
|
||||||
# Plans for this phase – may be empty in non-movement phases.
|
# Plans for this phase – may be empty in non-movement phases.
|
||||||
phase_obj = game_history._get_phase(phase_name)
|
phase_obj = game_history._get_phase(phase_name)
|
||||||
phase_block["state_history_plans"] = (
|
phase_block["state_history_plans"] = phase_obj.plans if phase_obj else {}
|
||||||
phase_obj.plans if phase_obj else {}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------------------------------- #
|
# -------------------------------------------------------------- #
|
||||||
# 4. Attach top-level metadata and write atomically. #
|
# 4. Attach top-level metadata and write atomically. #
|
||||||
# -------------------------------------------------------------- #
|
# -------------------------------------------------------------- #
|
||||||
saved_game["phase_summaries"] = getattr(game, "phase_summaries", {})
|
saved_game["phase_summaries"] = getattr(game, "phase_summaries", {})
|
||||||
saved_game["final_agent_states"] = {
|
saved_game["final_agent_states"] = {p_name: {"relationships": a.relationships, "goals": a.goals} for p_name, a in agents.items()}
|
||||||
p_name: {"relationships": a.relationships, "goals": a.goals}
|
|
||||||
for p_name, a in agents.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
# Filter out phases > max_year
|
# Filter out phases > max_year
|
||||||
#saved_game["phases"] = [
|
# saved_game["phases"] = [
|
||||||
# ph for ph in saved_game["phases"]
|
# ph for ph in saved_game["phases"]
|
||||||
# if int(ph["name"][1:5]) <= run_config.max_year # <= 1902, for example
|
# if int(ph["name"][1:5]) <= run_config.max_year # <= 1902, for example
|
||||||
#]
|
# ]
|
||||||
atomic_write_json(saved_game, output_path)
|
atomic_write_json(saved_game, output_path)
|
||||||
|
|
||||||
logger.info("Game state saved successfully.")
|
logger.info("Game state saved successfully.")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def load_game_state(
|
def load_game_state(
|
||||||
run_dir: str,
|
run_dir: str,
|
||||||
game_file_name: str,
|
game_file_name: str,
|
||||||
|
|
@ -204,7 +185,7 @@ def load_game_state(
|
||||||
raise FileNotFoundError(f"Cannot resume. Save file not found at: {game_file_path}")
|
raise FileNotFoundError(f"Cannot resume. Save file not found at: {game_file_path}")
|
||||||
|
|
||||||
logger.info(f"Loading game state from: {game_file_path}")
|
logger.info(f"Loading game state from: {game_file_path}")
|
||||||
with open(game_file_path, 'r') as f:
|
with open(game_file_path, "r") as f:
|
||||||
saved_game_data = json.load(f)
|
saved_game_data = json.load(f)
|
||||||
|
|
||||||
# If resuming, find the specified phase and truncate the data after it
|
# If resuming, find the specified phase and truncate the data after it
|
||||||
|
|
@ -213,48 +194,48 @@ def load_game_state(
|
||||||
try:
|
try:
|
||||||
# Find the index of the phase *before* the one we want to resume from.
|
# Find the index of the phase *before* the one we want to resume from.
|
||||||
# We will start the simulation *at* the resume_from_phase.
|
# We will start the simulation *at* the resume_from_phase.
|
||||||
resume_idx = next(i for i, phase in enumerate(saved_game_data['phases']) if phase['name'] == resume_from_phase)
|
resume_idx = next(i for i, phase in enumerate(saved_game_data["phases"]) if phase["name"] == resume_from_phase)
|
||||||
# Truncate the list to exclude everything after the resume phase
|
# Truncate the list to exclude everything after the resume phase
|
||||||
# Note: the state saved for a given phase represents the state at the beginning of that phase.
|
# Note: the state saved for a given phase represents the state at the beginning of that phase.
|
||||||
saved_game_data['phases'] = saved_game_data['phases'][:resume_idx+1]
|
saved_game_data["phases"] = saved_game_data["phases"][: resume_idx + 1]
|
||||||
|
|
||||||
# Wipe any data that must be regenerated.
|
# Wipe any data that must be regenerated.
|
||||||
for key in ("orders", "results", "messages"):
|
for key in ("orders", "results", "messages"):
|
||||||
saved_game_data['phases'][-1].pop(key, None)
|
saved_game_data["phases"][-1].pop(key, None)
|
||||||
logger.info(f"Game history truncated to {len(saved_game_data['phases'])} phases. The next phase to run will be {resume_from_phase}.")
|
logger.info(f"Game history truncated to {len(saved_game_data['phases'])} phases. The next phase to run will be {resume_from_phase}.")
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
# If the phase is not found, maybe it's the first phase (S1901M)
|
# If the phase is not found, maybe it's the first phase (S1901M)
|
||||||
if resume_from_phase == "S1901M":
|
if resume_from_phase == "S1901M":
|
||||||
saved_game_data['phases'] = []
|
saved_game_data["phases"] = []
|
||||||
logger.info("Resuming from S1901M. Starting with a clean history.")
|
logger.info("Resuming from S1901M. Starting with a clean history.")
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Resume phase '{resume_from_phase}' not found in the save file.")
|
raise ValueError(f"Resume phase '{resume_from_phase}' not found in the save file.")
|
||||||
|
|
||||||
# Reconstruct the Game object
|
# Reconstruct the Game object
|
||||||
last_phase = saved_game_data['phases'][-1]
|
last_phase = saved_game_data["phases"][-1]
|
||||||
|
|
||||||
# Wipe the data that must be regenerated **but preserve the keys**
|
# Wipe the data that must be regenerated **but preserve the keys**
|
||||||
last_phase['orders'] = {} # was dict
|
last_phase["orders"] = {} # was dict
|
||||||
last_phase['results'] = {} # was dict
|
last_phase["results"] = {} # was dict
|
||||||
last_phase['messages'] = []
|
last_phase["messages"] = []
|
||||||
|
|
||||||
game = from_saved_game_format(saved_game_data)
|
game = from_saved_game_format(saved_game_data)
|
||||||
|
|
||||||
game.phase_summaries = saved_game_data.get('phase_summaries', {})
|
game.phase_summaries = saved_game_data.get("phase_summaries", {})
|
||||||
|
|
||||||
# Reconstruct agents and game history from the *last* valid phase in the data
|
# Reconstruct agents and game history from the *last* valid phase in the data
|
||||||
if not saved_game_data['phases']:
|
if not saved_game_data["phases"]:
|
||||||
# This happens if we are resuming from the very beginning (S1901M)
|
# This happens if we are resuming from the very beginning (S1901M)
|
||||||
logger.info("No previous phases found. Initializing fresh agents and history.")
|
logger.info("No previous phases found. Initializing fresh agents and history.")
|
||||||
agents = {} # Will be created by the main loop
|
agents = {} # Will be created by the main loop
|
||||||
game_history = GameHistory()
|
game_history = GameHistory()
|
||||||
else:
|
else:
|
||||||
# We save the game state up to & including the current (uncompleted) phase.
|
# We save the game state up to & including the current (uncompleted) phase.
|
||||||
# So we need to grab the agent state from the previous (completed) phase.
|
# So we need to grab the agent state from the previous (completed) phase.
|
||||||
if len(saved_game_data['phases']) <= 1:
|
if len(saved_game_data["phases"]) <= 1:
|
||||||
last_phase_data = {}
|
last_phase_data = {}
|
||||||
else:
|
else:
|
||||||
last_phase_data = saved_game_data['phases'][-2]
|
last_phase_data = saved_game_data["phases"][-2]
|
||||||
|
|
||||||
# -------------------- Rebuild agents -------------------- #
|
# -------------------- Rebuild agents -------------------- #
|
||||||
agents = {}
|
agents = {}
|
||||||
|
|
@ -269,18 +250,11 @@ def load_game_state(
|
||||||
elif len(provided) == 1:
|
elif len(provided) == 1:
|
||||||
power_model_map = dict(zip(powers_order, provided * len(powers_order)))
|
power_model_map = dict(zip(powers_order, provided * len(powers_order)))
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(f"Invalid --models argument: expected 1 or {len(powers_order)} items, got {len(provided)}.")
|
||||||
f"Invalid --models argument: expected 1 or {len(powers_order)} "
|
|
||||||
f"items, got {len(provided)}."
|
|
||||||
)
|
|
||||||
|
|
||||||
if "state_agents" in last_phase_data:
|
if "state_agents" in last_phase_data:
|
||||||
logger.info("Rebuilding agents from saved state...")
|
logger.info("Rebuilding agents from saved state...")
|
||||||
prompts_dir_from_config = (
|
prompts_dir_from_config = run_config.prompts_dir if run_config and hasattr(run_config, "prompts_dir") else None
|
||||||
run_config.prompts_dir
|
|
||||||
if run_config and hasattr(run_config, "prompts_dir")
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
for power_name, agent_data in last_phase_data["state_agents"].items():
|
for power_name, agent_data in last_phase_data["state_agents"].items():
|
||||||
override_id = power_model_map.get(power_name)
|
override_id = power_model_map.get(power_name)
|
||||||
agents[power_name] = deserialize_agent(
|
agents[power_name] = deserialize_agent(
|
||||||
|
|
@ -290,37 +264,27 @@ def load_game_state(
|
||||||
)
|
)
|
||||||
logger.info(f"Rebuilt {len(agents)} agents.")
|
logger.info(f"Rebuilt {len(agents)} agents.")
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError("Cannot resume: 'state_agents' key not found in the last phase of the save file.")
|
||||||
"Cannot resume: 'state_agents' key not found in the last phase "
|
|
||||||
"of the save file."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Rebuild GameHistory
|
# Rebuild GameHistory
|
||||||
game_history = GameHistory()
|
game_history = GameHistory()
|
||||||
logger.info("Rebuilding game history...")
|
logger.info("Rebuilding game history...")
|
||||||
for phase_data in saved_game_data['phases'][:-1]:
|
for phase_data in saved_game_data["phases"][:-1]:
|
||||||
phase_name = phase_data['name']
|
phase_name = phase_data["name"]
|
||||||
game_history.add_phase(phase_name)
|
game_history.add_phase(phase_name)
|
||||||
# Add messages
|
# Add messages
|
||||||
for msg in phase_data.get('messages', []):
|
for msg in phase_data.get("messages", []):
|
||||||
game_history.add_message(phase_name, msg['sender'], msg['recipient'], msg['message'])
|
game_history.add_message(phase_name, msg["sender"], msg["recipient"], msg["message"])
|
||||||
# Add plans
|
# Add plans
|
||||||
if 'state_history_plans' in phase_data:
|
if "state_history_plans" in phase_data:
|
||||||
for p_name, plan in phase_data['state_history_plans'].items():
|
for p_name, plan in phase_data["state_history_plans"].items():
|
||||||
game_history.add_plan(phase_name, p_name, plan)
|
game_history.add_plan(phase_name, p_name, plan)
|
||||||
logger.info("Game history rebuilt.")
|
logger.info("Game history rebuilt.")
|
||||||
|
|
||||||
|
|
||||||
return game, agents, game_history, run_config
|
return game, agents, game_history, run_config
|
||||||
|
|
||||||
|
|
||||||
async def initialize_new_game(
|
async def initialize_new_game(args: Namespace, game: Game, game_history: GameHistory, llm_log_file_path: str) -> Dict[str, DiplomacyAgent]:
|
||||||
args: Namespace,
|
|
||||||
game: Game,
|
|
||||||
game_history: GameHistory,
|
|
||||||
llm_log_file_path: str
|
|
||||||
) -> Dict[str, DiplomacyAgent]:
|
|
||||||
"""Initializes agents for a new game."""
|
"""Initializes agents for a new game."""
|
||||||
powers_order = sorted(list(ALL_POWERS))
|
powers_order = sorted(list(ALL_POWERS))
|
||||||
|
|
||||||
|
|
@ -370,11 +334,11 @@ async def initialize_new_game(
|
||||||
|
|
||||||
initialized_powers = list(agents.keys())
|
initialized_powers = list(agents.keys())
|
||||||
for i, result in enumerate(initialization_results):
|
for i, result in enumerate(initialization_results):
|
||||||
if i < len(initialized_powers):
|
if i < len(initialized_powers):
|
||||||
power_name = initialized_powers[i]
|
power_name = initialized_powers[i]
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
logger.error(f"Failed to initialize agent state for {power_name}: {result}", exc_info=result)
|
logger.error(f"Failed to initialize agent state for {power_name}: {result}", exc_info=result)
|
||||||
else:
|
else:
|
||||||
logger.info(f"Successfully initialized agent state for {power_name}.")
|
logger.info(f"Successfully initialized agent state for {power_name}.")
|
||||||
|
|
||||||
return agents
|
return agents
|
||||||
|
|
@ -1,12 +1,11 @@
|
||||||
# ai_diplomacy/initialization.py
|
# ai_diplomacy/initialization.py
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
import os
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from ..config import config
|
from ..config import config
|
||||||
|
|
||||||
# Forward declaration for type hinting, actual imports in function if complex
|
# Forward declaration for type hinting, actual imports in function if complex
|
||||||
if False: # TYPE_CHECKING
|
if False: # TYPE_CHECKING
|
||||||
from diplomacy import Game
|
from diplomacy import Game
|
||||||
from diplomacy.models.game import GameHistory
|
from diplomacy.models.game import GameHistory
|
||||||
from .agent import DiplomacyAgent
|
from .agent import DiplomacyAgent
|
||||||
|
|
@ -18,40 +17,40 @@ from .formatter import format_with_gemini_flash, FORMAT_INITIAL_STATE
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def initialize_agent_state_ext(
|
async def initialize_agent_state_ext(
|
||||||
agent: 'DiplomacyAgent',
|
agent: "DiplomacyAgent",
|
||||||
game: 'Game',
|
game: "Game",
|
||||||
game_history: 'GameHistory',
|
game_history: "GameHistory",
|
||||||
log_file_path: str,
|
log_file_path: str,
|
||||||
prompts_dir: Optional[str] = None,
|
prompts_dir: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""Uses the LLM to set initial goals and relationships for the agent."""
|
"""Uses the LLM to set initial goals and relationships for the agent."""
|
||||||
power_name = agent.power_name
|
power_name = agent.power_name
|
||||||
logger.info(f"[{power_name}] Initializing agent state using LLM (external function)..." )
|
logger.info(f"[{power_name}] Initializing agent state using LLM (external function)...")
|
||||||
current_phase = game.get_current_phase() if game else "UnknownPhase"
|
current_phase = game.get_current_phase() if game else "UnknownPhase"
|
||||||
|
|
||||||
full_prompt = "" # Ensure full_prompt is defined in the outer scope for finally block
|
full_prompt = "" # Ensure full_prompt is defined in the outer scope for finally block
|
||||||
response = "" # Ensure response is defined for finally block
|
response = "" # Ensure response is defined for finally block
|
||||||
success_status = "Failure: Initialized" # Default status
|
success_status = "Failure: Initialized" # Default status
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Load the prompt template
|
# Load the prompt template
|
||||||
allowed_labels_str = ", ".join(ALLOWED_RELATIONSHIPS)
|
allowed_labels_str = ", ".join(ALLOWED_RELATIONSHIPS)
|
||||||
initial_prompt_template = _load_prompt_file(get_prompt_path('initial_state_prompt.txt'), prompts_dir=prompts_dir)
|
initial_prompt_template = _load_prompt_file(get_prompt_path("initial_state_prompt.txt"), prompts_dir=prompts_dir)
|
||||||
if not initial_prompt_template:
|
if not initial_prompt_template:
|
||||||
logger.error(f"[{power_name}] Could not load {get_prompt_path('initial_state_prompt.txt')}. Cannot initialize.")
|
logger.error(f"[{power_name}] Could not load {get_prompt_path('initial_state_prompt.txt')}. Cannot initialize.")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Format the prompt with variables
|
# Format the prompt with variables
|
||||||
initial_prompt = initial_prompt_template.format(
|
initial_prompt = initial_prompt_template.format(power_name=power_name, allowed_labels_str=allowed_labels_str)
|
||||||
power_name=power_name,
|
|
||||||
allowed_labels_str=allowed_labels_str
|
|
||||||
)
|
|
||||||
|
|
||||||
board_state = game.get_state() if game else {}
|
board_state = game.get_state() if game else {}
|
||||||
possible_orders = game.get_all_possible_orders() if game else {}
|
possible_orders = game.get_all_possible_orders() if game else {}
|
||||||
|
|
||||||
logger.debug(f"[{power_name}] Preparing context for initial state. Board state type: {type(board_state)}, possible_orders type: {type(possible_orders)}, game_history type: {type(game_history)}")
|
logger.debug(
|
||||||
|
f"[{power_name}] Preparing context for initial state. Board state type: {type(board_state)}, possible_orders type: {type(possible_orders)}, game_history type: {type(game_history)}"
|
||||||
|
)
|
||||||
# Ensure agent.client and its methods can handle None for game/board_state/etc. if that's a possibility
|
# Ensure agent.client and its methods can handle None for game/board_state/etc. if that's a possibility
|
||||||
# For initialization, game should always be present.
|
# For initialization, game should always be present.
|
||||||
|
|
||||||
|
|
@ -76,9 +75,9 @@ async def initialize_agent_state_ext(
|
||||||
log_file_path=log_file_path,
|
log_file_path=log_file_path,
|
||||||
power_name=power_name,
|
power_name=power_name,
|
||||||
phase=current_phase,
|
phase=current_phase,
|
||||||
response_type='initialization', # Context for run_llm_and_log internal error logging
|
response_type="initialization", # Context for run_llm_and_log internal error logging
|
||||||
)
|
)
|
||||||
logger.debug(f"[{power_name}] LLM response for initial state: {response[:300]}...") # Log a snippet
|
logger.debug(f"[{power_name}] LLM response for initial state: {response[:300]}...") # Log a snippet
|
||||||
|
|
||||||
parsed_successfully = False
|
parsed_successfully = False
|
||||||
try:
|
try:
|
||||||
|
|
@ -86,11 +85,7 @@ async def initialize_agent_state_ext(
|
||||||
if config.USE_UNFORMATTED_PROMPTS:
|
if config.USE_UNFORMATTED_PROMPTS:
|
||||||
# Format the natural language response into JSON
|
# Format the natural language response into JSON
|
||||||
formatted_response = await format_with_gemini_flash(
|
formatted_response = await format_with_gemini_flash(
|
||||||
response,
|
response, FORMAT_INITIAL_STATE, power_name=power_name, phase=current_phase, log_file_path=log_file_path
|
||||||
FORMAT_INITIAL_STATE,
|
|
||||||
power_name=power_name,
|
|
||||||
phase=current_phase,
|
|
||||||
log_file_path=log_file_path
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Use the raw response directly (already formatted)
|
# Use the raw response directly (already formatted)
|
||||||
|
|
@ -101,20 +96,24 @@ async def initialize_agent_state_ext(
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.error(f"[{power_name}] All JSON extraction attempts failed: {e}. Response snippet: {response[:300]}...")
|
logger.error(f"[{power_name}] All JSON extraction attempts failed: {e}. Response snippet: {response[:300]}...")
|
||||||
success_status = "Failure: JSONDecodeError"
|
success_status = "Failure: JSONDecodeError"
|
||||||
update_data = {} # Ensure update_data exists for fallback logic below
|
update_data = {} # Ensure update_data exists for fallback logic below
|
||||||
parsed_successfully = False # Explicitly set here too
|
parsed_successfully = False # Explicitly set here too
|
||||||
# Fallback logic for goals/relationships will be handled later if update_data is empty
|
# Fallback logic for goals/relationships will be handled later if update_data is empty
|
||||||
|
|
||||||
# Defensive check for update_data type if parsing was initially considered successful
|
# Defensive check for update_data type if parsing was initially considered successful
|
||||||
if parsed_successfully:
|
if parsed_successfully:
|
||||||
if isinstance(update_data, str):
|
if isinstance(update_data, str):
|
||||||
logger.error(f"[{power_name}] _extract_json_from_text returned a string, not a dict/list, despite not raising an exception. This indicates an unexpected parsing issue. String returned: {update_data[:300]}...")
|
logger.error(
|
||||||
update_data = {} # Treat as parsing failure
|
f"[{power_name}] _extract_json_from_text returned a string, not a dict/list, despite not raising an exception. This indicates an unexpected parsing issue. String returned: {update_data[:300]}..."
|
||||||
|
)
|
||||||
|
update_data = {} # Treat as parsing failure
|
||||||
parsed_successfully = False
|
parsed_successfully = False
|
||||||
success_status = "Failure: ParsedAsStr"
|
success_status = "Failure: ParsedAsStr"
|
||||||
elif not isinstance(update_data, dict): # Expecting a dict from JSON object
|
elif not isinstance(update_data, dict): # Expecting a dict from JSON object
|
||||||
logger.error(f"[{power_name}] _extract_json_from_text returned a non-dict type ({type(update_data)}), expected dict. Data: {str(update_data)[:300]}")
|
logger.error(
|
||||||
update_data = {} # Treat as parsing failure
|
f"[{power_name}] _extract_json_from_text returned a non-dict type ({type(update_data)}), expected dict. Data: {str(update_data)[:300]}"
|
||||||
|
)
|
||||||
|
update_data = {} # Treat as parsing failure
|
||||||
parsed_successfully = False
|
parsed_successfully = False
|
||||||
success_status = "Failure: NotADict"
|
success_status = "Failure: NotADict"
|
||||||
|
|
||||||
|
|
@ -122,8 +121,8 @@ async def initialize_agent_state_ext(
|
||||||
initial_relationships_applied = False
|
initial_relationships_applied = False
|
||||||
|
|
||||||
if parsed_successfully:
|
if parsed_successfully:
|
||||||
initial_goals = update_data.get('initial_goals') or update_data.get('goals')
|
initial_goals = update_data.get("initial_goals") or update_data.get("goals")
|
||||||
initial_relationships = update_data.get('initial_relationships') or update_data.get('relationships')
|
initial_relationships = update_data.get("initial_relationships") or update_data.get("relationships")
|
||||||
|
|
||||||
if isinstance(initial_goals, list) and initial_goals:
|
if isinstance(initial_goals, list) and initial_goals:
|
||||||
agent.goals = initial_goals
|
agent.goals = initial_goals
|
||||||
|
|
@ -152,25 +151,25 @@ async def initialize_agent_state_ext(
|
||||||
else:
|
else:
|
||||||
logger.warning(f"[{power_name}] No valid relationships found in LLM response.")
|
logger.warning(f"[{power_name}] No valid relationships found in LLM response.")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"[{power_name}] LLM did not provide valid 'initial_relationships' dict (got: {initial_relationships}).")
|
logger.warning(f"[{power_name}] LLM did not provide valid 'initial_relationships' dict (got: {initial_relationships}).")
|
||||||
|
|
||||||
if initial_goals_applied or initial_relationships_applied:
|
if initial_goals_applied or initial_relationships_applied:
|
||||||
success_status = "Success: Applied LLM data"
|
success_status = "Success: Applied LLM data"
|
||||||
elif parsed_successfully: # Parsed but nothing useful to apply
|
elif parsed_successfully: # Parsed but nothing useful to apply
|
||||||
success_status = "Success: Parsed but no data applied"
|
success_status = "Success: Parsed but no data applied"
|
||||||
# If not parsed_successfully, success_status is already "Failure: JSONDecodeError"
|
# If not parsed_successfully, success_status is already "Failure: JSONDecodeError"
|
||||||
|
|
||||||
# Fallback if LLM data was not applied or parsing failed
|
# Fallback if LLM data was not applied or parsing failed
|
||||||
if not initial_goals_applied:
|
if not initial_goals_applied:
|
||||||
if not agent.goals: # Only set defaults if no goals were set during agent construction or by LLM
|
if not agent.goals: # Only set defaults if no goals were set during agent construction or by LLM
|
||||||
agent.goals = ["Survive and expand", "Form beneficial alliances", "Secure key territories"]
|
agent.goals = ["Survive and expand", "Form beneficial alliances", "Secure key territories"]
|
||||||
agent.add_journal_entry(f"[{current_phase}] Set default initial goals as LLM provided none or parse failed.")
|
agent.add_journal_entry(f"[{current_phase}] Set default initial goals as LLM provided none or parse failed.")
|
||||||
logger.info(f"[{power_name}] Default goals set.")
|
logger.info(f"[{power_name}] Default goals set.")
|
||||||
|
|
||||||
if not initial_relationships_applied:
|
if not initial_relationships_applied:
|
||||||
# Check if relationships are still default-like before overriding
|
# Check if relationships are still default-like before overriding
|
||||||
is_default_relationships = True
|
is_default_relationships = True
|
||||||
if agent.relationships: # Check if it's not empty
|
if agent.relationships: # Check if it's not empty
|
||||||
for p in ALL_POWERS:
|
for p in ALL_POWERS:
|
||||||
if p != power_name and agent.relationships.get(p) != "Neutral":
|
if p != power_name and agent.relationships.get(p) != "Neutral":
|
||||||
is_default_relationships = False
|
is_default_relationships = False
|
||||||
|
|
@ -191,16 +190,16 @@ async def initialize_agent_state_ext(
|
||||||
agent.relationships = {p: "Neutral" for p in ALL_POWERS if p != power_name}
|
agent.relationships = {p: "Neutral" for p in ALL_POWERS if p != power_name}
|
||||||
logger.info(f"[{power_name}] Set fallback neutral relationships after top-level error: {agent.relationships}")
|
logger.info(f"[{power_name}] Set fallback neutral relationships after top-level error: {agent.relationships}")
|
||||||
finally:
|
finally:
|
||||||
if log_file_path: # Ensure log_file_path is provided
|
if log_file_path: # Ensure log_file_path is provided
|
||||||
log_llm_response(
|
log_llm_response(
|
||||||
log_file_path=log_file_path,
|
log_file_path=log_file_path,
|
||||||
model_name=agent.client.model_name if agent and agent.client else "UnknownModel",
|
model_name=agent.client.model_name if agent and agent.client else "UnknownModel",
|
||||||
power_name=power_name,
|
power_name=power_name,
|
||||||
phase=current_phase,
|
phase=current_phase,
|
||||||
response_type="initial_state_setup", # Specific type for CSV logging
|
response_type="initial_state_setup", # Specific type for CSV logging
|
||||||
raw_input_prompt=full_prompt,
|
raw_input_prompt=full_prompt,
|
||||||
raw_response=response,
|
raw_response=response,
|
||||||
success=success_status
|
success=success_status,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Final log of state after initialization attempt
|
# Final log of state after initialization attempt
|
||||||
|
|
|
||||||
|
|
@ -209,4 +209,3 @@ def _patched_generate_phase_summary(self: Game, phase_key, summary_callback=None
|
||||||
Game._generate_phase_summary = _patched_generate_phase_summary # type: ignore[assignment]
|
Game._generate_phase_summary = _patched_generate_phase_summary # type: ignore[assignment]
|
||||||
|
|
||||||
LOGGER.info("Game._generate_phase_summary patched with narrative generation.")
|
LOGGER.info("Game._generate_phase_summary patched with narrative generation.")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,7 @@ from typing import Dict, TYPE_CHECKING
|
||||||
from diplomacy.engine.message import Message, GLOBAL
|
from diplomacy.engine.message import Message, GLOBAL
|
||||||
|
|
||||||
from .agent import DiplomacyAgent
|
from .agent import DiplomacyAgent
|
||||||
from .clients import load_model_client
|
from .utils import gather_possible_orders, normalize_recipient_name
|
||||||
from .utils import gather_possible_orders, load_prompt, normalize_recipient_name
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .game_history import GameHistory
|
from .game_history import GameHistory
|
||||||
|
|
@ -21,9 +20,9 @@ load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
async def conduct_negotiations(
|
async def conduct_negotiations(
|
||||||
game: 'Game',
|
game: "Game",
|
||||||
agents: Dict[str, DiplomacyAgent],
|
agents: Dict[str, DiplomacyAgent],
|
||||||
game_history: 'GameHistory',
|
game_history: "GameHistory",
|
||||||
model_error_stats: Dict[str, Dict[str, int]],
|
model_error_stats: Dict[str, Dict[str, int]],
|
||||||
log_file_path: str,
|
log_file_path: str,
|
||||||
max_rounds: int = 3,
|
max_rounds: int = 3,
|
||||||
|
|
@ -35,12 +34,8 @@ async def conduct_negotiations(
|
||||||
"""
|
"""
|
||||||
logger.info("Starting negotiation phase.")
|
logger.info("Starting negotiation phase.")
|
||||||
|
|
||||||
active_powers = [
|
active_powers = [p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()]
|
||||||
p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()
|
eliminated_powers = [p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()]
|
||||||
]
|
|
||||||
eliminated_powers = [
|
|
||||||
p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()
|
|
||||||
]
|
|
||||||
|
|
||||||
logger.info(f"Active powers for negotiations: {active_powers}")
|
logger.info(f"Active powers for negotiations: {active_powers}")
|
||||||
if eliminated_powers:
|
if eliminated_powers:
|
||||||
|
|
@ -99,28 +94,28 @@ async def conduct_negotiations(
|
||||||
# Process results
|
# Process results
|
||||||
for i, result in enumerate(results):
|
for i, result in enumerate(results):
|
||||||
power_name = power_names_for_tasks[i]
|
power_name = power_names_for_tasks[i]
|
||||||
agent = agents[power_name] # Get agent again for journaling
|
agent = agents[power_name] # Get agent again for journaling
|
||||||
model_name = agent.client.model_name # Get model name for stats
|
model_name = agent.client.model_name # Get model name for stats
|
||||||
|
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
logger.error(f"Error getting conversation reply for {power_name}: {result}", exc_info=result)
|
logger.error(f"Error getting conversation reply for {power_name}: {result}", exc_info=result)
|
||||||
# Use model_name for stats key if possible
|
# Use model_name for stats key if possible
|
||||||
if model_name in model_error_stats:
|
if model_name in model_error_stats:
|
||||||
model_error_stats[model_name]["conversation_errors"] += 1
|
model_error_stats[model_name]["conversation_errors"] += 1
|
||||||
else: # Fallback to power_name if model name not tracked (shouldn't happen)
|
else: # Fallback to power_name if model name not tracked (shouldn't happen)
|
||||||
model_error_stats.setdefault(power_name, {}).setdefault("conversation_errors", 0)
|
model_error_stats.setdefault(power_name, {}).setdefault("conversation_errors", 0)
|
||||||
model_error_stats[power_name]["conversation_errors"] += 1
|
model_error_stats[power_name]["conversation_errors"] += 1
|
||||||
messages = [] # Treat as no messages on error
|
messages = [] # Treat as no messages on error
|
||||||
elif result is None: # Handle case where client might return None on internal error
|
elif result is None: # Handle case where client might return None on internal error
|
||||||
logger.warning(f"Received None instead of messages for {power_name}.")
|
logger.warning(f"Received None instead of messages for {power_name}.")
|
||||||
messages = []
|
messages = []
|
||||||
if model_name in model_error_stats:
|
if model_name in model_error_stats:
|
||||||
model_error_stats[model_name]["conversation_errors"] += 1
|
model_error_stats[model_name]["conversation_errors"] += 1
|
||||||
else:
|
else:
|
||||||
model_error_stats.setdefault(power_name, {}).setdefault("conversation_errors", 0)
|
model_error_stats.setdefault(power_name, {}).setdefault("conversation_errors", 0)
|
||||||
model_error_stats[power_name]["conversation_errors"] += 1
|
model_error_stats[power_name]["conversation_errors"] += 1
|
||||||
else:
|
else:
|
||||||
messages = result # result is the list of message dicts
|
messages = result # result is the list of message dicts
|
||||||
logger.debug(f"Received {len(messages)} message(s) from {power_name}.")
|
logger.debug(f"Received {len(messages)} message(s) from {power_name}.")
|
||||||
|
|
||||||
# Process the received messages (same logic as before)
|
# Process the received messages (same logic as before)
|
||||||
|
|
@ -134,27 +129,27 @@ async def conduct_negotiations(
|
||||||
# Create an official message in the Diplomacy engine
|
# Create an official message in the Diplomacy engine
|
||||||
# Determine recipient based on message type
|
# Determine recipient based on message type
|
||||||
if message.get("message_type") == "private":
|
if message.get("message_type") == "private":
|
||||||
recipient = normalize_recipient_name(message.get("recipient", GLOBAL)) # Default to GLOBAL if recipient missing somehow
|
recipient = normalize_recipient_name(message.get("recipient", GLOBAL)) # Default to GLOBAL if recipient missing somehow
|
||||||
if recipient not in game.powers and recipient != GLOBAL:
|
if recipient not in game.powers and recipient != GLOBAL:
|
||||||
logger.warning(f"Invalid recipient '{recipient}' in message from {power_name}. Sending globally.")
|
logger.warning(f"Invalid recipient '{recipient}' in message from {power_name}. Sending globally.")
|
||||||
recipient = GLOBAL # Fallback to GLOBAL if recipient power is invalid
|
recipient = GLOBAL # Fallback to GLOBAL if recipient power is invalid
|
||||||
else: # Assume global if not private or type is missing
|
else: # Assume global if not private or type is missing
|
||||||
recipient = GLOBAL
|
recipient = GLOBAL
|
||||||
|
|
||||||
diplo_message = Message(
|
diplo_message = Message(
|
||||||
phase=game.current_short_phase,
|
phase=game.current_short_phase,
|
||||||
sender=power_name,
|
sender=power_name,
|
||||||
recipient=recipient, # Use determined recipient
|
recipient=recipient, # Use determined recipient
|
||||||
message=message.get("content", ""), # Use .get for safety
|
message=message.get("content", ""), # Use .get for safety
|
||||||
time_sent=None, # Let the engine assign time
|
time_sent=None, # Let the engine assign time
|
||||||
)
|
)
|
||||||
game.add_message(diplo_message)
|
game.add_message(diplo_message)
|
||||||
# Also add to our custom history
|
# Also add to our custom history
|
||||||
game_history.add_message(
|
game_history.add_message(
|
||||||
game.current_short_phase,
|
game.current_short_phase,
|
||||||
power_name,
|
power_name,
|
||||||
recipient, # Use determined recipient here too
|
recipient, # Use determined recipient here too
|
||||||
message.get("content", ""), # Use .get for safety
|
message.get("content", ""), # Use .get for safety
|
||||||
)
|
)
|
||||||
journal_recipient = f"to {recipient}" if recipient != GLOBAL else "globally"
|
journal_recipient = f"to {recipient}" if recipient != GLOBAL else "globally"
|
||||||
agent.add_journal_entry(f"Sent message {journal_recipient} in {game.current_short_phase}: {message.get('content', '')[:100]}...")
|
agent.add_journal_entry(f"Sent message {journal_recipient} in {game.current_short_phase}: {message.get('content', '')[:100]}...")
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,13 @@
|
||||||
from dotenv import load_dotenv
|
|
||||||
import logging
|
import logging
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from .clients import load_model_client
|
|
||||||
from .game_history import GameHistory
|
from .game_history import GameHistory
|
||||||
from .agent import DiplomacyAgent
|
from .agent import DiplomacyAgent
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def planning_phase(
|
async def planning_phase(
|
||||||
game,
|
game,
|
||||||
agents: Dict[str, DiplomacyAgent],
|
agents: Dict[str, DiplomacyAgent],
|
||||||
|
|
@ -20,12 +19,8 @@ async def planning_phase(
|
||||||
Lets each power generate a strategic plan using their DiplomacyAgent.
|
Lets each power generate a strategic plan using their DiplomacyAgent.
|
||||||
"""
|
"""
|
||||||
logger.info(f"Starting planning phase for {game.current_short_phase}...")
|
logger.info(f"Starting planning phase for {game.current_short_phase}...")
|
||||||
active_powers = [
|
active_powers = [p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()]
|
||||||
p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()
|
eliminated_powers = [p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()]
|
||||||
]
|
|
||||||
eliminated_powers = [
|
|
||||||
p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()
|
|
||||||
]
|
|
||||||
|
|
||||||
logger.info(f"Active powers for planning: {active_powers}")
|
logger.info(f"Active powers for planning: {active_powers}")
|
||||||
if eliminated_powers:
|
if eliminated_powers:
|
||||||
|
|
@ -35,9 +30,7 @@ async def planning_phase(
|
||||||
|
|
||||||
board_state = game.get_state()
|
board_state = game.get_state()
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor(
|
with concurrent.futures.ThreadPoolExecutor(max_workers=len(active_powers)) as executor:
|
||||||
max_workers=len(active_powers)
|
|
||||||
) as executor:
|
|
||||||
futures = {}
|
futures = {}
|
||||||
for power_name in active_powers:
|
for power_name in active_powers:
|
||||||
if power_name not in agents:
|
if power_name not in agents:
|
||||||
|
|
@ -69,18 +62,16 @@ async def planning_phase(
|
||||||
logger.info(f"Received planning result from {power_name}.")
|
logger.info(f"Received planning result from {power_name}.")
|
||||||
|
|
||||||
if plan_result.startswith("Error:"):
|
if plan_result.startswith("Error:"):
|
||||||
logger.warning(f"Agent {power_name} reported an error during planning: {plan_result}")
|
logger.warning(f"Agent {power_name} reported an error during planning: {plan_result}")
|
||||||
if power_name in model_error_stats:
|
if power_name in model_error_stats:
|
||||||
model_error_stats[power_name].setdefault('planning_generation_errors', 0)
|
model_error_stats[power_name].setdefault("planning_generation_errors", 0)
|
||||||
model_error_stats[power_name]['planning_generation_errors'] += 1
|
model_error_stats[power_name]["planning_generation_errors"] += 1
|
||||||
else:
|
else:
|
||||||
model_error_stats.setdefault(f'{power_name}_planning_generation_errors', 0)
|
model_error_stats.setdefault(f"{power_name}_planning_generation_errors", 0)
|
||||||
model_error_stats[f'{power_name}_planning_generation_errors'] += 1
|
model_error_stats[f"{power_name}_planning_generation_errors"] += 1
|
||||||
elif plan_result:
|
elif plan_result:
|
||||||
agent.add_journal_entry(f"Generated plan for {game.current_short_phase}: {plan_result[:100]}...")
|
agent.add_journal_entry(f"Generated plan for {game.current_short_phase}: {plan_result[:100]}...")
|
||||||
game_history.add_plan(
|
game_history.add_plan(game.current_short_phase, power_name, plan_result)
|
||||||
game.current_short_phase, power_name, plan_result
|
|
||||||
)
|
|
||||||
logger.debug(f"Added plan for {power_name} to history.")
|
logger.debug(f"Added plan for {power_name} to history.")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Agent {power_name} returned an empty plan.")
|
logger.warning(f"Agent {power_name} returned an empty plan.")
|
||||||
|
|
@ -88,11 +79,11 @@ async def planning_phase(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Exception during planning result processing for {power_name}: {e}")
|
logger.error(f"Exception during planning result processing for {power_name}: {e}")
|
||||||
if power_name in model_error_stats:
|
if power_name in model_error_stats:
|
||||||
model_error_stats[power_name].setdefault('planning_execution_errors', 0)
|
model_error_stats[power_name].setdefault("planning_execution_errors", 0)
|
||||||
model_error_stats[power_name]['planning_execution_errors'] += 1
|
model_error_stats[power_name]["planning_execution_errors"] += 1
|
||||||
else:
|
else:
|
||||||
model_error_stats.setdefault(f'{power_name}_planning_execution_errors', 0)
|
model_error_stats.setdefault(f"{power_name}_planning_execution_errors", 0)
|
||||||
model_error_stats[f'{power_name}_planning_execution_errors'] += 1
|
model_error_stats[f"{power_name}_planning_execution_errors"] += 1
|
||||||
|
|
||||||
logger.info("Planning phase processing complete.")
|
logger.info("Planning phase processing complete.")
|
||||||
return game_history
|
return game_history
|
||||||
|
|
@ -14,6 +14,7 @@ import re
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]:
|
def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]:
|
||||||
"""
|
"""
|
||||||
Return graph[PROV]['ARMY'|'FLEET'] = list of 3-letter neighbour provinces.
|
Return graph[PROV]['ARMY'|'FLEET'] = list of 3-letter neighbour provinces.
|
||||||
|
|
@ -22,20 +23,18 @@ def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]:
|
||||||
"""
|
"""
|
||||||
# ── collect all 3-letter province codes ───────────────────────────────
|
# ── collect all 3-letter province codes ───────────────────────────────
|
||||||
provs: Set[str] = {
|
provs: Set[str] = {
|
||||||
loc.split("/")[0][:3].upper() # 'BUL/EC' -> 'BUL'
|
loc.split("/")[0][:3].upper() # 'BUL/EC' -> 'BUL'
|
||||||
for loc in game_map.locs
|
for loc in game_map.locs
|
||||||
if len(loc.split("/")[0]) == 3
|
if len(loc.split("/")[0]) == 3
|
||||||
}
|
}
|
||||||
|
|
||||||
graph: Dict[str, Dict[str, List[str]]] = {
|
graph: Dict[str, Dict[str, List[str]]] = {p: {"ARMY": [], "FLEET": []} for p in provs}
|
||||||
p: {"ARMY": [], "FLEET": []} for p in provs
|
|
||||||
}
|
|
||||||
|
|
||||||
# ── helper: list every concrete variant of a province ─────────────────
|
# ── helper: list every concrete variant of a province ─────────────────
|
||||||
def variants(code: str) -> List[str]:
|
def variants(code: str) -> List[str]:
|
||||||
lst = list(game_map.loc_coasts.get(code, []))
|
lst = list(game_map.loc_coasts.get(code, []))
|
||||||
if code not in lst:
|
if code not in lst:
|
||||||
lst.append(code) # ensure base node included
|
lst.append(code) # ensure base node included
|
||||||
return lst
|
return lst
|
||||||
|
|
||||||
# ── populate adjacency by brute-force queries to `abuts()` ────────────
|
# ── populate adjacency by brute-force queries to `abuts()` ────────────
|
||||||
|
|
@ -55,42 +54,37 @@ def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]:
|
||||||
graph[src]["ARMY"].append(dest)
|
graph[src]["ARMY"].append(dest)
|
||||||
|
|
||||||
# FLEET — any src variant that can host a fleet is valid
|
# FLEET — any src variant that can host a fleet is valid
|
||||||
if any(
|
if any(game_map.abuts("F", sv, "-", dv) for sv in src_vers for dv in dest_vers):
|
||||||
game_map.abuts("F", sv, "-", dv)
|
|
||||||
for sv in src_vers
|
|
||||||
for dv in dest_vers
|
|
||||||
):
|
|
||||||
graph[src]["FLEET"].append(dest)
|
graph[src]["FLEET"].append(dest)
|
||||||
|
|
||||||
# ── tidy up duplicates / order ---------------------------------------
|
# ── tidy up duplicates / order ---------------------------------------
|
||||||
for p in graph:
|
for p in graph:
|
||||||
graph[p]["ARMY"] = sorted(set(graph[p]["ARMY"]))
|
graph[p]["ARMY"] = sorted(set(graph[p]["ARMY"]))
|
||||||
graph[p]["FLEET"] = sorted(set(graph[p]["FLEET"]))
|
graph[p]["FLEET"] = sorted(set(graph[p]["FLEET"]))
|
||||||
|
|
||||||
return graph
|
return graph
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def bfs_shortest_path(
|
def bfs_shortest_path(
|
||||||
graph: Dict[str, Dict[str, List[str]]],
|
graph: Dict[str, Dict[str, List[str]]],
|
||||||
board_state: BoardState,
|
board_state: BoardState,
|
||||||
game_map: GameMap, # Added game_map
|
game_map: GameMap, # Added game_map
|
||||||
start_loc_full: str, # This is a FULL location name like 'VIE' or 'STP/SC'
|
start_loc_full: str, # This is a FULL location name like 'VIE' or 'STP/SC'
|
||||||
unit_type: str,
|
unit_type: str,
|
||||||
is_target_func: Callable[[str, BoardState], bool] # Expects SHORT name for loc
|
is_target_func: Callable[[str, BoardState], bool], # Expects SHORT name for loc
|
||||||
) -> Optional[List[str]]: # Returns path of SHORT names
|
) -> Optional[List[str]]: # Returns path of SHORT names
|
||||||
"""Performs BFS to find the shortest path from start_loc to a target satisfying is_target_func."""
|
"""Performs BFS to find the shortest path from start_loc to a target satisfying is_target_func."""
|
||||||
|
|
||||||
# Convert full start location to short province name
|
# Convert full start location to short province name
|
||||||
start_loc_short = game_map.loc_name.get(start_loc_full, start_loc_full)
|
start_loc_short = game_map.loc_name.get(start_loc_full, start_loc_full)
|
||||||
if '/' in start_loc_short: # If it was STP/SC, loc_name gives STP. If it was VIE, loc_name gives VIE.
|
if "/" in start_loc_short: # If it was STP/SC, loc_name gives STP. If it was VIE, loc_name gives VIE.
|
||||||
start_loc_short = start_loc_short[:3]
|
start_loc_short = start_loc_short[:3]
|
||||||
# If start_loc_full was already short (e.g. 'VIE'), get might return it as is, or its value if it was a key.
|
# If start_loc_full was already short (e.g. 'VIE'), get might return it as is, or its value if it was a key.
|
||||||
# A simpler way for non-coastal full (like 'VIE') or already short:
|
# A simpler way for non-coastal full (like 'VIE') or already short:
|
||||||
if '/' not in start_loc_full:
|
if "/" not in start_loc_full:
|
||||||
start_loc_short = start_loc_full[:3] # Ensures 'VIE' -> 'VIE', 'PAR' -> 'PAR'
|
start_loc_short = start_loc_full[:3] # Ensures 'VIE' -> 'VIE', 'PAR' -> 'PAR'
|
||||||
else: # Has '/', e.g. 'STP/SC'
|
else: # Has '/', e.g. 'STP/SC'
|
||||||
start_loc_short = start_loc_full[:3] # 'STP/SC' -> 'STP'
|
start_loc_short = start_loc_full[:3] # 'STP/SC' -> 'STP'
|
||||||
|
|
||||||
if start_loc_short not in graph:
|
if start_loc_short not in graph:
|
||||||
logger.warning(f"BFS: Start province {start_loc_short} (from {start_loc_full}) not in graph. Pathfinding may fail.")
|
logger.warning(f"BFS: Start province {start_loc_short} (from {start_loc_full}) not in graph. Pathfinding may fail.")
|
||||||
|
|
@ -104,14 +98,14 @@ def bfs_shortest_path(
|
||||||
|
|
||||||
# is_target_func expects a short location name
|
# is_target_func expects a short location name
|
||||||
if is_target_func(current_loc_short, board_state):
|
if is_target_func(current_loc_short, board_state):
|
||||||
return path # Path of short names
|
return path # Path of short names
|
||||||
|
|
||||||
# possible_neighbors are SHORT names from the graph
|
# possible_neighbors are SHORT names from the graph
|
||||||
possible_neighbors_short = graph.get(current_loc_short, {}).get(unit_type, [])
|
possible_neighbors_short = graph.get(current_loc_short, {}).get(unit_type, [])
|
||||||
|
|
||||||
for next_loc_short in possible_neighbors_short:
|
for next_loc_short in possible_neighbors_short:
|
||||||
if next_loc_short not in visited_nodes:
|
if next_loc_short not in visited_nodes:
|
||||||
if next_loc_short not in graph: # Defensive check for neighbors not in graph keys
|
if next_loc_short not in graph: # Defensive check for neighbors not in graph keys
|
||||||
logger.warning(f"BFS: Neighbor {next_loc_short} of {current_loc_short} not in graph. Skipping.")
|
logger.warning(f"BFS: Neighbor {next_loc_short} of {current_loc_short} not in graph. Skipping.")
|
||||||
continue
|
continue
|
||||||
visited_nodes.add(next_loc_short)
|
visited_nodes.add(next_loc_short)
|
||||||
|
|
@ -119,11 +113,12 @@ def bfs_shortest_path(
|
||||||
queue.append((next_loc_short, new_path))
|
queue.append((next_loc_short, new_path))
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
# --- Helper functions for context generation ---
|
# --- Helper functions for context generation ---
|
||||||
def get_unit_at_location(board_state: BoardState, location: str) -> Optional[str]:
|
def get_unit_at_location(board_state: BoardState, location: str) -> Optional[str]:
|
||||||
"""Returns the full unit string (e.g., 'A PAR (FRA)') if a unit is at the location, else None."""
|
"""Returns the full unit string (e.g., 'A PAR (FRA)') if a unit is at the location, else None."""
|
||||||
for power, unit_list in board_state.get('units', {}).items():
|
for power, unit_list in board_state.get("units", {}).items():
|
||||||
for unit_str in unit_list: # e.g., "A PAR", "F STP/SC"
|
for unit_str in unit_list: # e.g., "A PAR", "F STP/SC"
|
||||||
parts = unit_str.split(" ")
|
parts = unit_str.split(" ")
|
||||||
if len(parts) == 2:
|
if len(parts) == 2:
|
||||||
unit_map_loc = parts[1]
|
unit_map_loc = parts[1]
|
||||||
|
|
@ -131,24 +126,26 @@ def get_unit_at_location(board_state: BoardState, location: str) -> Optional[str
|
||||||
return f"{parts[0]} {location} ({power})"
|
return f"{parts[0]} {location} ({power})"
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_sc_controller(game_map: GameMap, board_state: BoardState, location: str) -> Optional[str]:
|
def get_sc_controller(game_map: GameMap, board_state: BoardState, location: str) -> Optional[str]:
|
||||||
"""Returns the controlling power's name if the location is an SC, else None."""
|
"""Returns the controlling power's name if the location is an SC, else None."""
|
||||||
# Normalize location to base province name, as SCs are tied to provinces, not specific coasts
|
# Normalize location to base province name, as SCs are tied to provinces, not specific coasts
|
||||||
loc_province_name = game_map.loc_name.get(location, location).upper()[:3]
|
loc_province_name = game_map.loc_name.get(location, location).upper()[:3]
|
||||||
if loc_province_name not in game_map.scs:
|
if loc_province_name not in game_map.scs:
|
||||||
return None
|
return None
|
||||||
for power, sc_list in board_state.get('centers', {}).items():
|
for power, sc_list in board_state.get("centers", {}).items():
|
||||||
if loc_province_name in sc_list:
|
if loc_province_name in sc_list:
|
||||||
return power
|
return power
|
||||||
return None # Unowned SC
|
return None # Unowned SC
|
||||||
|
|
||||||
|
|
||||||
def get_shortest_path_to_friendly_unit(
|
def get_shortest_path_to_friendly_unit(
|
||||||
board_state: BoardState,
|
board_state: BoardState,
|
||||||
graph: Dict[str, Dict[str, List[str]]],
|
graph: Dict[str, Dict[str, List[str]]],
|
||||||
game_map: GameMap, # Added game_map
|
game_map: GameMap, # Added game_map
|
||||||
power_name: str,
|
power_name: str,
|
||||||
start_unit_loc_full: str,
|
start_unit_loc_full: str,
|
||||||
start_unit_type: str
|
start_unit_type: str,
|
||||||
) -> Optional[Tuple[str, List[str]]]:
|
) -> Optional[Tuple[str, List[str]]]:
|
||||||
"""Finds the shortest path to any friendly unit of the same power."""
|
"""Finds the shortest path to any friendly unit of the same power."""
|
||||||
|
|
||||||
|
|
@ -162,7 +159,7 @@ def get_shortest_path_to_friendly_unit(
|
||||||
return False
|
return False
|
||||||
|
|
||||||
path_short_names = bfs_shortest_path(graph, board_state, game_map, start_unit_loc_full, start_unit_type, is_target_friendly)
|
path_short_names = bfs_shortest_path(graph, board_state, game_map, start_unit_loc_full, start_unit_type, is_target_friendly)
|
||||||
if path_short_names and len(path_short_names) > 1: # Path includes start, so > 1 means a distinct friendly unit found
|
if path_short_names and len(path_short_names) > 1: # Path includes start, so > 1 means a distinct friendly unit found
|
||||||
target_loc_short = path_short_names[-1]
|
target_loc_short = path_short_names[-1]
|
||||||
# Find the actual friendly unit string at one of the full locations of target_loc_short
|
# Find the actual friendly unit string at one of the full locations of target_loc_short
|
||||||
friendly_unit_str = "UNKNOWN_FRIENDLY_UNIT"
|
friendly_unit_str = "UNKNOWN_FRIENDLY_UNIT"
|
||||||
|
|
@ -179,35 +176,35 @@ def get_shortest_path_to_friendly_unit(
|
||||||
def get_nearest_enemy_units(
|
def get_nearest_enemy_units(
|
||||||
board_state: BoardState,
|
board_state: BoardState,
|
||||||
graph: Dict[str, Dict[str, List[str]]],
|
graph: Dict[str, Dict[str, List[str]]],
|
||||||
game_map: GameMap, # Added game_map
|
game_map: GameMap, # Added game_map
|
||||||
power_name: str,
|
power_name: str,
|
||||||
start_unit_loc_full: str,
|
start_unit_loc_full: str,
|
||||||
start_unit_type: str,
|
start_unit_type: str,
|
||||||
n: int = 3
|
n: int = 3,
|
||||||
) -> List[Tuple[str, List[str]]]:
|
) -> List[Tuple[str, List[str]]]:
|
||||||
"""Finds up to N nearest enemy units, sorted by path length."""
|
"""Finds up to N nearest enemy units, sorted by path length."""
|
||||||
enemy_paths: List[Tuple[str, List[str]]] = [] # (enemy_unit_str, path_short_names)
|
enemy_paths: List[Tuple[str, List[str]]] = [] # (enemy_unit_str, path_short_names)
|
||||||
|
|
||||||
all_enemy_unit_locations_full: List[Tuple[str,str]] = [] # (loc_full, unit_str_full)
|
all_enemy_unit_locations_full: List[Tuple[str, str]] = [] # (loc_full, unit_str_full)
|
||||||
# board_state.get("units", {}) has format: { "POWER_NAME": ["A PAR", "F BRE"], ... }
|
# board_state.get("units", {}) has format: { "POWER_NAME": ["A PAR", "F BRE"], ... }
|
||||||
for p_name, unit_list_for_power in board_state.get("units", {}).items():
|
for p_name, unit_list_for_power in board_state.get("units", {}).items():
|
||||||
if p_name != power_name: # If it's an enemy power
|
if p_name != power_name: # If it's an enemy power
|
||||||
for unit_repr_from_state in unit_list_for_power: # e.g., "A PAR" or "F STP/SC"
|
for unit_repr_from_state in unit_list_for_power: # e.g., "A PAR" or "F STP/SC"
|
||||||
parts = unit_repr_from_state.split(" ")
|
parts = unit_repr_from_state.split(" ")
|
||||||
if len(parts) == 2:
|
if len(parts) == 2:
|
||||||
# unit_type_char = parts[0] # 'A' or 'F'
|
# unit_type_char = parts[0] # 'A' or 'F'
|
||||||
loc_full = parts[1] # 'PAR' or 'STP/SC'
|
loc_full = parts[1] # 'PAR' or 'STP/SC'
|
||||||
|
|
||||||
# Use get_unit_at_location to get the consistent full unit string like "A PAR (POWER_NAME)"
|
# Use get_unit_at_location to get the consistent full unit string like "A PAR (POWER_NAME)"
|
||||||
full_unit_str_with_power = get_unit_at_location(board_state, loc_full)
|
full_unit_str_with_power = get_unit_at_location(board_state, loc_full)
|
||||||
if full_unit_str_with_power: # Should find the unit if iteration is correct
|
if full_unit_str_with_power: # Should find the unit if iteration is correct
|
||||||
all_enemy_unit_locations_full.append((loc_full, full_unit_str_with_power))
|
all_enemy_unit_locations_full.append((loc_full, full_unit_str_with_power))
|
||||||
|
|
||||||
for target_enemy_loc_full, enemy_unit_str in all_enemy_unit_locations_full:
|
for target_enemy_loc_full, enemy_unit_str in all_enemy_unit_locations_full:
|
||||||
target_enemy_loc_short = game_map.loc_name.get(target_enemy_loc_full, target_enemy_loc_full)
|
target_enemy_loc_short = game_map.loc_name.get(target_enemy_loc_full, target_enemy_loc_full)
|
||||||
if '/' in target_enemy_loc_short:
|
if "/" in target_enemy_loc_short:
|
||||||
target_enemy_loc_short = target_enemy_loc_short[:3]
|
target_enemy_loc_short = target_enemy_loc_short[:3]
|
||||||
if '/' not in target_enemy_loc_full:
|
if "/" not in target_enemy_loc_full:
|
||||||
target_enemy_loc_short = target_enemy_loc_full[:3]
|
target_enemy_loc_short = target_enemy_loc_full[:3]
|
||||||
else:
|
else:
|
||||||
target_enemy_loc_short = target_enemy_loc_full[:3]
|
target_enemy_loc_short = target_enemy_loc_full[:3]
|
||||||
|
|
@ -220,7 +217,7 @@ def get_nearest_enemy_units(
|
||||||
if path_short_names:
|
if path_short_names:
|
||||||
enemy_paths.append((enemy_unit_str, path_short_names))
|
enemy_paths.append((enemy_unit_str, path_short_names))
|
||||||
|
|
||||||
enemy_paths.sort(key=lambda x: len(x[1])) # Sort by path length
|
enemy_paths.sort(key=lambda x: len(x[1])) # Sort by path length
|
||||||
return enemy_paths[:n]
|
return enemy_paths[:n]
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -242,10 +239,10 @@ def get_nearest_uncontrolled_scs(
|
||||||
"""
|
"""
|
||||||
results: List[Tuple[str, int, List[str]]] = []
|
results: List[Tuple[str, int, List[str]]] = []
|
||||||
|
|
||||||
for sc_short in game_map.scs: # all SC province codes
|
for sc_short in game_map.scs: # all SC province codes
|
||||||
controller = get_sc_controller(game_map, board_state, sc_short)
|
controller = get_sc_controller(game_map, board_state, sc_short)
|
||||||
if controller == power_name:
|
if controller == power_name:
|
||||||
continue # already ours
|
continue # already ours
|
||||||
|
|
||||||
# helper for BFS target test
|
# helper for BFS target test
|
||||||
def is_target(loc_short: str, _state: BoardState) -> bool:
|
def is_target(loc_short: str, _state: BoardState) -> bool:
|
||||||
|
|
@ -260,9 +257,9 @@ def get_nearest_uncontrolled_scs(
|
||||||
is_target,
|
is_target,
|
||||||
)
|
)
|
||||||
if not path:
|
if not path:
|
||||||
continue # unreachable
|
continue # unreachable
|
||||||
|
|
||||||
distance = len(path) - 1 # moves needed
|
distance = len(path) - 1 # moves needed
|
||||||
|
|
||||||
# skip distance 0 (same province) and 1 (adjacent)
|
# skip distance 0 (same province) and 1 (adjacent)
|
||||||
if distance <= 1:
|
if distance <= 1:
|
||||||
|
|
@ -279,9 +276,9 @@ def get_nearest_uncontrolled_scs(
|
||||||
def get_adjacent_territory_details(
|
def get_adjacent_territory_details(
|
||||||
game_map: GameMap,
|
game_map: GameMap,
|
||||||
board_state: BoardState,
|
board_state: BoardState,
|
||||||
unit_loc_full: str, # The location of the unit whose adjacencies we're checking
|
unit_loc_full: str, # The location of the unit whose adjacencies we're checking
|
||||||
unit_type: str, # ARMY or FLEET of the unit at unit_loc_full
|
unit_type: str, # ARMY or FLEET of the unit at unit_loc_full
|
||||||
graph: Dict[str, Dict[str, List[str]]]
|
graph: Dict[str, Dict[str, List[str]]],
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generates a string describing adjacent territories and units that can interact with them."""
|
"""Generates a string describing adjacent territories and units that can interact with them."""
|
||||||
output_lines: List[str] = []
|
output_lines: List[str] = []
|
||||||
|
|
@ -291,29 +288,29 @@ def get_adjacent_territory_details(
|
||||||
# For fleets, graph[unit_loc_full]['FLEET'] gives full loc names (incl coasts)
|
# For fleets, graph[unit_loc_full]['FLEET'] gives full loc names (incl coasts)
|
||||||
# THIS COMMENT IS NOW OUTDATED. Graph uses short names for keys and values.
|
# THIS COMMENT IS NOW OUTDATED. Graph uses short names for keys and values.
|
||||||
unit_loc_short = game_map.loc_name.get(unit_loc_full, unit_loc_full)
|
unit_loc_short = game_map.loc_name.get(unit_loc_full, unit_loc_full)
|
||||||
if '/' in unit_loc_short:
|
if "/" in unit_loc_short:
|
||||||
unit_loc_short = unit_loc_short[:3]
|
unit_loc_short = unit_loc_short[:3]
|
||||||
if '/' not in unit_loc_full:
|
if "/" not in unit_loc_full:
|
||||||
unit_loc_short = unit_loc_full[:3]
|
unit_loc_short = unit_loc_full[:3]
|
||||||
else:
|
else:
|
||||||
unit_loc_short = unit_loc_full[:3]
|
unit_loc_short = unit_loc_full[:3]
|
||||||
|
|
||||||
adjacent_locs_short_for_unit = graph.get(unit_loc_short, {}).get(unit_type, [])
|
adjacent_locs_short_for_unit = graph.get(unit_loc_short, {}).get(unit_type, [])
|
||||||
|
|
||||||
processed_adj_provinces = set() # To handle cases like STP/NC and STP/SC both being adjacent to BOT
|
processed_adj_provinces = set() # To handle cases like STP/NC and STP/SC both being adjacent to BOT
|
||||||
|
|
||||||
for adj_loc_short in adjacent_locs_short_for_unit: # adj_loc_short is already short
|
for adj_loc_short in adjacent_locs_short_for_unit: # adj_loc_short is already short
|
||||||
# adj_province_short = game_map.loc_name.get(adj_loc_full, adj_loc_full).upper()[:3] # No longer needed
|
# adj_province_short = game_map.loc_name.get(adj_loc_full, adj_loc_full).upper()[:3] # No longer needed
|
||||||
if adj_loc_short in processed_adj_provinces: # adj_loc_short is already short and upper implicitly by map data
|
if adj_loc_short in processed_adj_provinces: # adj_loc_short is already short and upper implicitly by map data
|
||||||
continue
|
continue
|
||||||
processed_adj_provinces.add(adj_loc_short)
|
processed_adj_provinces.add(adj_loc_short)
|
||||||
|
|
||||||
adj_loc_type = game_map.loc_type.get(adj_loc_short, 'UNKNOWN').upper()
|
adj_loc_type = game_map.loc_type.get(adj_loc_short, "UNKNOWN").upper()
|
||||||
if adj_loc_type == 'COAST' or adj_loc_type == 'LAND':
|
if adj_loc_type == "COAST" or adj_loc_type == "LAND":
|
||||||
adj_loc_type_display = 'LAND' if adj_loc_type == 'LAND' else 'COAST'
|
adj_loc_type_display = "LAND" if adj_loc_type == "LAND" else "COAST"
|
||||||
elif adj_loc_type == 'WATER':
|
elif adj_loc_type == "WATER":
|
||||||
adj_loc_type_display = 'WATER'
|
adj_loc_type_display = "WATER"
|
||||||
else: # SHUT etc.
|
else: # SHUT etc.
|
||||||
adj_loc_type_display = adj_loc_type
|
adj_loc_type_display = adj_loc_type
|
||||||
|
|
||||||
line = f" {adj_loc_short} ({adj_loc_type_display})"
|
line = f" {adj_loc_short} ({adj_loc_type_display})"
|
||||||
|
|
@ -330,15 +327,14 @@ def get_adjacent_territory_details(
|
||||||
# "Can support/move to" - Simplified: list units in *further* adjacent provinces
|
# "Can support/move to" - Simplified: list units in *further* adjacent provinces
|
||||||
# A true "can support/move to" would require checking possible orders of those further units.
|
# A true "can support/move to" would require checking possible orders of those further units.
|
||||||
# further_adj_provinces are short names from the graph
|
# further_adj_provinces are short names from the graph
|
||||||
further_adj_provinces_short = graph.get(adj_loc_short, {}).get('ARMY', []) + \
|
further_adj_provinces_short = graph.get(adj_loc_short, {}).get("ARMY", []) + graph.get(adj_loc_short, {}).get("FLEET", [])
|
||||||
graph.get(adj_loc_short, {}).get('FLEET', [])
|
|
||||||
|
|
||||||
supporting_units_info = []
|
supporting_units_info = []
|
||||||
processed_further_provinces = set()
|
processed_further_provinces = set()
|
||||||
for further_adj_loc_short in further_adj_provinces_short:
|
for further_adj_loc_short in further_adj_provinces_short:
|
||||||
# further_adj_province_short = game_map.loc_name.get(further_adj_loc_full, further_adj_loc_full).upper()[:3]
|
# further_adj_province_short = game_map.loc_name.get(further_adj_loc_full, further_adj_loc_full).upper()[:3]
|
||||||
# No conversion needed, it's already short
|
# No conversion needed, it's already short
|
||||||
if further_adj_loc_short == adj_loc_short or further_adj_loc_short == unit_loc_short: # Don't list itself or origin
|
if further_adj_loc_short == adj_loc_short or further_adj_loc_short == unit_loc_short: # Don't list itself or origin
|
||||||
continue
|
continue
|
||||||
if further_adj_loc_short in processed_further_provinces:
|
if further_adj_loc_short in processed_further_provinces:
|
||||||
continue
|
continue
|
||||||
|
|
@ -354,7 +350,7 @@ def get_adjacent_territory_details(
|
||||||
temp_unit = get_unit_at_location(board_state, fv_further)
|
temp_unit = get_unit_at_location(board_state, fv_further)
|
||||||
if temp_unit:
|
if temp_unit:
|
||||||
unit_in_further_loc = temp_unit
|
unit_in_further_loc = temp_unit
|
||||||
break # Found a unit in one of the coasts/base
|
break # Found a unit in one of the coasts/base
|
||||||
|
|
||||||
# if not unit_in_further_loc and further_adj_loc_full != further_adj_province_short:
|
# if not unit_in_further_loc and further_adj_loc_full != further_adj_province_short:
|
||||||
# unit_in_further_loc = get_unit_at_location(board_state, further_adj_province_short)
|
# unit_in_further_loc = get_unit_at_location(board_state, further_adj_province_short)
|
||||||
|
|
@ -384,16 +380,16 @@ def generate_rich_order_context_xml(game: Any, power_name: str, possible_orders_
|
||||||
# Iterate through units that have orders (keys of possible_orders_for_power are unit locations)
|
# Iterate through units that have orders (keys of possible_orders_for_power are unit locations)
|
||||||
for unit_loc_full, unit_specific_possible_orders in possible_orders_for_power.items():
|
for unit_loc_full, unit_specific_possible_orders in possible_orders_for_power.items():
|
||||||
unit_str_full = get_unit_at_location(board_state, unit_loc_full)
|
unit_str_full = get_unit_at_location(board_state, unit_loc_full)
|
||||||
if not unit_str_full: # Should not happen if unit_loc_full is from possible_orders keys
|
if not unit_str_full: # Should not happen if unit_loc_full is from possible_orders keys
|
||||||
continue
|
continue
|
||||||
|
|
||||||
unit_type_char = unit_str_full.split(" ")[0] # 'A' or 'F'
|
unit_type_char = unit_str_full.split(" ")[0] # 'A' or 'F'
|
||||||
unit_type_long = "ARMY" if unit_type_char == 'A' else "FLEET"
|
unit_type_long = "ARMY" if unit_type_char == "A" else "FLEET"
|
||||||
|
|
||||||
loc_province_short = game_map.loc_name.get(unit_loc_full, unit_loc_full).upper()[:3]
|
loc_province_short = game_map.loc_name.get(unit_loc_full, unit_loc_full).upper()[:3]
|
||||||
loc_type_short = game_map.loc_type.get(loc_province_short, "UNKNOWN").upper()
|
loc_type_short = game_map.loc_type.get(loc_province_short, "UNKNOWN").upper()
|
||||||
if loc_type_short == 'COAST' or loc_type_short == 'LAND':
|
if loc_type_short == "COAST" or loc_type_short == "LAND":
|
||||||
loc_type_display = 'LAND' if loc_type_short == 'LAND' else 'COAST'
|
loc_type_display = "LAND" if loc_type_short == "LAND" else "COAST"
|
||||||
else:
|
else:
|
||||||
loc_type_display = loc_type_short
|
loc_type_display = loc_type_short
|
||||||
|
|
||||||
|
|
@ -401,7 +397,7 @@ def generate_rich_order_context_xml(game: Any, power_name: str, possible_orders_
|
||||||
current_unit_lines.append(f' <UnitContext loc="{unit_loc_full}">')
|
current_unit_lines.append(f' <UnitContext loc="{unit_loc_full}">')
|
||||||
|
|
||||||
# Unit Information section
|
# Unit Information section
|
||||||
current_unit_lines.append(' <UnitInformation>')
|
current_unit_lines.append(" <UnitInformation>")
|
||||||
sc_owner_at_loc = get_sc_controller(game_map, board_state, unit_loc_full)
|
sc_owner_at_loc = get_sc_controller(game_map, board_state, unit_loc_full)
|
||||||
header_content = f"Strategic territory held by {power_name}: {unit_loc_full} ({loc_type_display})"
|
header_content = f"Strategic territory held by {power_name}: {unit_loc_full} ({loc_type_display})"
|
||||||
if sc_owner_at_loc == power_name:
|
if sc_owner_at_loc == power_name:
|
||||||
|
|
@ -410,55 +406,61 @@ def generate_rich_order_context_xml(game: Any, power_name: str, possible_orders_
|
||||||
header_content += f" (SC controlled by {sc_owner_at_loc})"
|
header_content += f" (SC controlled by {sc_owner_at_loc})"
|
||||||
current_unit_lines.append(f" {header_content}")
|
current_unit_lines.append(f" {header_content}")
|
||||||
current_unit_lines.append(f" Units present: {unit_str_full}")
|
current_unit_lines.append(f" Units present: {unit_str_full}")
|
||||||
current_unit_lines.append(' </UnitInformation>')
|
current_unit_lines.append(" </UnitInformation>")
|
||||||
|
|
||||||
# Possible moves section
|
# Possible moves section
|
||||||
current_unit_lines.append(' <PossibleMoves>')
|
current_unit_lines.append(" <PossibleMoves>")
|
||||||
current_unit_lines.append(" Possible moves:")
|
current_unit_lines.append(" Possible moves:")
|
||||||
for order_str in unit_specific_possible_orders:
|
for order_str in unit_specific_possible_orders:
|
||||||
current_unit_lines.append(f" {order_str}")
|
current_unit_lines.append(f" {order_str}")
|
||||||
current_unit_lines.append(' </PossibleMoves>')
|
current_unit_lines.append(" </PossibleMoves>")
|
||||||
|
|
||||||
# Nearest enemy units section
|
# Nearest enemy units section
|
||||||
enemy_units_info = get_nearest_enemy_units(board_state, graph, game_map, power_name, unit_loc_full, unit_type_long, n=3)
|
enemy_units_info = get_nearest_enemy_units(board_state, graph, game_map, power_name, unit_loc_full, unit_type_long, n=3)
|
||||||
current_unit_lines.append(' <NearestEnemyUnits>')
|
current_unit_lines.append(" <NearestEnemyUnits>")
|
||||||
if enemy_units_info:
|
if enemy_units_info:
|
||||||
current_unit_lines.append(" Nearest units (not ours):")
|
current_unit_lines.append(" Nearest units (not ours):")
|
||||||
for enemy_unit_str, enemy_path_short in enemy_units_info:
|
for enemy_unit_str, enemy_path_short in enemy_units_info:
|
||||||
current_unit_lines.append(f" {enemy_unit_str}, path=[{unit_loc_full}→{('→'.join(enemy_path_short[1:])) if len(enemy_path_short) > 1 else enemy_path_short[0]}]")
|
current_unit_lines.append(
|
||||||
|
f" {enemy_unit_str}, path=[{unit_loc_full}→{('→'.join(enemy_path_short[1:])) if len(enemy_path_short) > 1 else enemy_path_short[0]}]"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
current_unit_lines.append(" Nearest units (not ours): None found")
|
current_unit_lines.append(" Nearest units (not ours): None found")
|
||||||
current_unit_lines.append(' </NearestEnemyUnits>')
|
current_unit_lines.append(" </NearestEnemyUnits>")
|
||||||
|
|
||||||
# Nearest supply centers (not controlled by us) section
|
# Nearest supply centers (not controlled by us) section
|
||||||
uncontrolled_scs_info = get_nearest_uncontrolled_scs(game_map, board_state, graph, power_name, unit_loc_full, unit_type_long, n=3)
|
uncontrolled_scs_info = get_nearest_uncontrolled_scs(game_map, board_state, graph, power_name, unit_loc_full, unit_type_long, n=3)
|
||||||
current_unit_lines.append(' <NearestUncontrolledSupplyCenters>')
|
current_unit_lines.append(" <NearestUncontrolledSupplyCenters>")
|
||||||
if uncontrolled_scs_info:
|
if uncontrolled_scs_info:
|
||||||
current_unit_lines.append(" Nearest supply centers (not controlled by us):")
|
current_unit_lines.append(" Nearest supply centers (not controlled by us):")
|
||||||
for sc_str, dist, sc_path_short in uncontrolled_scs_info:
|
for sc_str, dist, sc_path_short in uncontrolled_scs_info:
|
||||||
current_unit_lines.append(f" {sc_str}, dist={dist}, path=[{unit_loc_full}→{('→'.join(sc_path_short[1:])) if len(sc_path_short) > 1 else sc_path_short[0]}]")
|
current_unit_lines.append(
|
||||||
|
f" {sc_str}, dist={dist}, path=[{unit_loc_full}→{('→'.join(sc_path_short[1:])) if len(sc_path_short) > 1 else sc_path_short[0]}]"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
current_unit_lines.append(" Nearest supply centers (not controlled by us): None found")
|
current_unit_lines.append(" Nearest supply centers (not controlled by us): None found")
|
||||||
current_unit_lines.append(' </NearestUncontrolledSupplyCenters>')
|
current_unit_lines.append(" </NearestUncontrolledSupplyCenters>")
|
||||||
|
|
||||||
# Adjacent territories details section
|
# Adjacent territories details section
|
||||||
adj_details_str = get_adjacent_territory_details(game_map, board_state, unit_loc_full, unit_type_long, graph)
|
adj_details_str = get_adjacent_territory_details(game_map, board_state, unit_loc_full, unit_type_long, graph)
|
||||||
current_unit_lines.append(' <AdjacentTerritories>')
|
current_unit_lines.append(" <AdjacentTerritories>")
|
||||||
if adj_details_str:
|
if adj_details_str:
|
||||||
current_unit_lines.append(" Adjacent territories (including units that can support/move to the adjacent territory):")
|
current_unit_lines.append(" Adjacent territories (including units that can support/move to the adjacent territory):")
|
||||||
# Assuming adj_details_str is already formatted with newlines and indentation for its content
|
# Assuming adj_details_str is already formatted with newlines and indentation for its content
|
||||||
# We might need to indent adj_details_str if it's a single block of text
|
# We might need to indent adj_details_str if it's a single block of text
|
||||||
# For now, let's add a standard indent to each line of adj_details_str if it contains newlines
|
# For now, let's add a standard indent to each line of adj_details_str if it contains newlines
|
||||||
if '\n' in adj_details_str:
|
if "\n" in adj_details_str:
|
||||||
indented_adj_details = "\n".join([f" {line}" for line in adj_details_str.split('\n')])
|
indented_adj_details = "\n".join([f" {line}" for line in adj_details_str.split("\n")])
|
||||||
current_unit_lines.append(indented_adj_details)
|
current_unit_lines.append(indented_adj_details)
|
||||||
else:
|
else:
|
||||||
current_unit_lines.append(f" {adj_details_str}")
|
current_unit_lines.append(f" {adj_details_str}")
|
||||||
else:
|
else:
|
||||||
current_unit_lines.append(" Adjacent territories: None relevant or all are empty/uncontested by direct threats.") # Added more descriptive else
|
current_unit_lines.append(
|
||||||
current_unit_lines.append(' </AdjacentTerritories>')
|
" Adjacent territories: None relevant or all are empty/uncontested by direct threats."
|
||||||
|
) # Added more descriptive else
|
||||||
|
current_unit_lines.append(" </AdjacentTerritories>")
|
||||||
|
|
||||||
current_unit_lines.append(' </UnitContext>')
|
current_unit_lines.append(" </UnitContext>")
|
||||||
final_context_lines.extend(current_unit_lines)
|
final_context_lines.extend(current_unit_lines)
|
||||||
|
|
||||||
final_context_lines.append("</PossibleOrdersContext>")
|
final_context_lines.append("</PossibleOrdersContext>")
|
||||||
|
|
@ -469,16 +471,16 @@ def generate_rich_order_context_xml(game: Any, power_name: str, possible_orders_
|
||||||
# Regex and tiny helpers
|
# Regex and tiny helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
import re
|
|
||||||
from typing import Tuple, List, Dict, Optional, Any
|
from typing import Tuple, List, Dict, Optional, Any
|
||||||
|
|
||||||
# ── order-syntax matchers ─────────────────────────────────────────────────
|
# ── order-syntax matchers ─────────────────────────────────────────────────
|
||||||
_SIMPLE_MOVE_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? - [A-Z]{3}(?:/[A-Z]{2})?$")
|
_SIMPLE_MOVE_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? - [A-Z]{3}(?:/[A-Z]{2})?$")
|
||||||
_HOLD_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? H$") # NEW
|
_HOLD_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? H$") # NEW
|
||||||
_RETREAT_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? R [A-Z]{3}(?:/[A-Z]{2})?$")
|
_RETREAT_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? R [A-Z]{3}(?:/[A-Z]{2})?$")
|
||||||
_ADJUST_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? [BD]$") # build / disband
|
_ADJUST_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? [BD]$") # build / disband
|
||||||
|
|
||||||
def _is_hold_order(order: str) -> bool: # NEW
|
|
||||||
|
def _is_hold_order(order: str) -> bool: # NEW
|
||||||
return bool(_HOLD_RE.match(order.strip()))
|
return bool(_HOLD_RE.match(order.strip()))
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -486,15 +488,19 @@ def _norm_power(name: str) -> str:
|
||||||
"""Trim & uppercase for reliable comparisons."""
|
"""Trim & uppercase for reliable comparisons."""
|
||||||
return name.strip().upper()
|
return name.strip().upper()
|
||||||
|
|
||||||
|
|
||||||
def _is_simple_move(order: str) -> bool:
|
def _is_simple_move(order: str) -> bool:
|
||||||
return bool(_SIMPLE_MOVE_RE.match(order.strip()))
|
return bool(_SIMPLE_MOVE_RE.match(order.strip()))
|
||||||
|
|
||||||
|
|
||||||
def _is_retreat_order(order: str) -> bool:
|
def _is_retreat_order(order: str) -> bool:
|
||||||
return bool(_RETREAT_RE.match(order.strip()))
|
return bool(_RETREAT_RE.match(order.strip()))
|
||||||
|
|
||||||
|
|
||||||
def _is_adjust_order(order: str) -> bool:
|
def _is_adjust_order(order: str) -> bool:
|
||||||
return bool(_ADJUST_RE.match(order.strip()))
|
return bool(_ADJUST_RE.match(order.strip()))
|
||||||
|
|
||||||
|
|
||||||
def _split_move(order: str) -> Tuple[str, str]:
|
def _split_move(order: str) -> Tuple[str, str]:
|
||||||
"""Return ('A BUD', 'TRI') from 'A BUD - TRI' (validated move only)."""
|
"""Return ('A BUD', 'TRI') from 'A BUD - TRI' (validated move only)."""
|
||||||
unit_part, dest = order.split(" - ")
|
unit_part, dest = order.split(" - ")
|
||||||
|
|
@ -505,6 +511,7 @@ def _split_move(order: str) -> Tuple[str, str]:
|
||||||
# Gather *all* friendly support orders for a given move
|
# Gather *all* friendly support orders for a given move
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _all_support_examples(
|
def _all_support_examples(
|
||||||
mover: str,
|
mover: str,
|
||||||
dest: str,
|
dest: str,
|
||||||
|
|
@ -526,6 +533,7 @@ def _all_support_examples(
|
||||||
|
|
||||||
return supports
|
return supports
|
||||||
|
|
||||||
|
|
||||||
def _all_support_hold_examples(
|
def _all_support_hold_examples(
|
||||||
holder: str,
|
holder: str,
|
||||||
all_orders: Dict[str, List[str]],
|
all_orders: Dict[str, List[str]],
|
||||||
|
|
@ -538,7 +546,7 @@ def _all_support_hold_examples(
|
||||||
supports: List[str] = []
|
supports: List[str] = []
|
||||||
|
|
||||||
for loc, orders in all_orders.items():
|
for loc, orders in all_orders.items():
|
||||||
if holder.endswith(loc): # skip the holding unit
|
if holder.endswith(loc): # skip the holding unit
|
||||||
continue
|
continue
|
||||||
for o in orders:
|
for o in orders:
|
||||||
if o.strip().endswith(target):
|
if o.strip().endswith(target):
|
||||||
|
|
@ -550,6 +558,7 @@ def _all_support_hold_examples(
|
||||||
# Province-type resolver (handles short codes, coasts, seas)
|
# Province-type resolver (handles short codes, coasts, seas)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _province_type_display(game_map, prov_short: str) -> str:
|
def _province_type_display(game_map, prov_short: str) -> str:
|
||||||
"""
|
"""
|
||||||
Return 'LAND', 'COAST', or 'WATER' for the 3-letter province code.
|
Return 'LAND', 'COAST', or 'WATER' for the 3-letter province code.
|
||||||
|
|
@ -575,7 +584,7 @@ def _dest_occupancy_desc(
|
||||||
board_state,
|
board_state,
|
||||||
our_power: str,
|
our_power: str,
|
||||||
) -> str:
|
) -> str:
|
||||||
""" '(occupied by X)', '(occupied by X — you!)', or '(unoccupied)' """
|
"""'(occupied by X)', '(occupied by X — you!)', or '(unoccupied)'"""
|
||||||
occupant: Optional[str] = None
|
occupant: Optional[str] = None
|
||||||
for full in game_map.loc_coasts.get(dest_short, [dest_short]):
|
for full in game_map.loc_coasts.get(dest_short, [dest_short]):
|
||||||
u = get_unit_at_location(board_state, full)
|
u = get_unit_at_location(board_state, full)
|
||||||
|
|
@ -593,6 +602,7 @@ def _dest_occupancy_desc(
|
||||||
# Adjacent-territory lines (used by movement-phase builder)
|
# Adjacent-territory lines (used by movement-phase builder)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _adjacent_territory_lines(
|
def _adjacent_territory_lines(
|
||||||
graph,
|
graph,
|
||||||
game_map,
|
game_map,
|
||||||
|
|
@ -633,13 +643,9 @@ def _adjacent_territory_lines(
|
||||||
pwr = unit_here.split(" ")[-1].strip("()")
|
pwr = unit_here.split(" ")[-1].strip("()")
|
||||||
if pwr == our_power:
|
if pwr == our_power:
|
||||||
friend_descr = unit_here.split(" (")[0]
|
friend_descr = unit_here.split(" (")[0]
|
||||||
lines.append(
|
lines.append(f"{indent2}Support hold: {mover_descr} S {friend_descr}")
|
||||||
f"{indent2}Support hold: {mover_descr} S {friend_descr}"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
lines.append(
|
lines.append(f"{indent2}-> {unit_here} can support or contest {mover_descr}’s moves and vice-versa")
|
||||||
f"{indent2}-> {unit_here} can support or contest {mover_descr}’s moves and vice-versa"
|
|
||||||
)
|
|
||||||
|
|
||||||
return lines
|
return lines
|
||||||
|
|
||||||
|
|
@ -648,6 +654,7 @@ def _adjacent_territory_lines(
|
||||||
# Movement-phase generator (UNCHANGED LOGIC)
|
# Movement-phase generator (UNCHANGED LOGIC)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _generate_rich_order_context_movement(
|
def _generate_rich_order_context_movement(
|
||||||
game: Any,
|
game: Any,
|
||||||
power_name: str,
|
power_name: str,
|
||||||
|
|
@ -672,18 +679,14 @@ def _generate_rich_order_context_movement(
|
||||||
if _norm_power(unit_power) != me:
|
if _norm_power(unit_power) != me:
|
||||||
continue # Skip units that aren’t ours
|
continue # Skip units that aren’t ours
|
||||||
|
|
||||||
mover_descr, _ = _split_move(
|
mover_descr, _ = _split_move(f"{unit_full_str.split(' ')[0]} {unit_loc_full} - {unit_loc_full}")
|
||||||
f"{unit_full_str.split(' ')[0]} {unit_loc_full} - {unit_loc_full}"
|
|
||||||
)
|
|
||||||
|
|
||||||
prov_short = game_map.loc_name.get(unit_loc_full, unit_loc_full)[:3]
|
prov_short = game_map.loc_name.get(unit_loc_full, unit_loc_full)[:3]
|
||||||
prov_type_disp = _province_type_display(game_map, prov_short)
|
prov_type_disp = _province_type_display(game_map, prov_short)
|
||||||
sc_tag = " (SC)" if prov_short in game_map.scs else ""
|
sc_tag = " (SC)" if prov_short in game_map.scs else ""
|
||||||
|
|
||||||
owner = get_sc_controller(game_map, board_state, unit_loc_full) or "None"
|
owner = get_sc_controller(game_map, board_state, unit_loc_full) or "None"
|
||||||
owner_line = (
|
owner_line = f"Held by {owner} (You)" if owner == power_name else f"Held by {owner}"
|
||||||
f"Held by {owner} (You)" if owner == power_name else f"Held by {owner}"
|
|
||||||
)
|
|
||||||
|
|
||||||
ind = " "
|
ind = " "
|
||||||
block: List[str] = [f"<Territory {prov_short}>"]
|
block: List[str] = [f"<Territory {prov_short}>"]
|
||||||
|
|
@ -693,18 +696,16 @@ def _generate_rich_order_context_movement(
|
||||||
|
|
||||||
# ----- Adjacent territories -----
|
# ----- Adjacent territories -----
|
||||||
block.append("# Adjacent territories:")
|
block.append("# Adjacent territories:")
|
||||||
block.extend(
|
block.extend(_adjacent_territory_lines(graph, game_map, board_state, unit_loc_full, mover_descr, power_name))
|
||||||
_adjacent_territory_lines(
|
|
||||||
graph, game_map, board_state,
|
|
||||||
unit_loc_full, mover_descr, power_name
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# ----- Nearest enemy units -----
|
# ----- Nearest enemy units -----
|
||||||
block.append("# Nearest units (not ours):")
|
block.append("# Nearest units (not ours):")
|
||||||
enemies = get_nearest_enemy_units(
|
enemies = get_nearest_enemy_units(
|
||||||
board_state, graph, game_map,
|
board_state,
|
||||||
power_name, unit_loc_full,
|
graph,
|
||||||
|
game_map,
|
||||||
|
power_name,
|
||||||
|
unit_loc_full,
|
||||||
"ARMY" if mover_descr.startswith("A") else "FLEET",
|
"ARMY" if mover_descr.startswith("A") else "FLEET",
|
||||||
n=3,
|
n=3,
|
||||||
)
|
)
|
||||||
|
|
@ -715,8 +716,11 @@ def _generate_rich_order_context_movement(
|
||||||
# ----- Nearest uncontrolled SCs -----
|
# ----- Nearest uncontrolled SCs -----
|
||||||
block.append("# Nearest supply centers (not controlled by us):")
|
block.append("# Nearest supply centers (not controlled by us):")
|
||||||
scs = get_nearest_uncontrolled_scs(
|
scs = get_nearest_uncontrolled_scs(
|
||||||
game_map, board_state, graph,
|
game_map,
|
||||||
power_name, unit_loc_full,
|
board_state,
|
||||||
|
graph,
|
||||||
|
power_name,
|
||||||
|
unit_loc_full,
|
||||||
"ARMY" if mover_descr.startswith("A") else "FLEET",
|
"ARMY" if mover_descr.startswith("A") else "FLEET",
|
||||||
n=3,
|
n=3,
|
||||||
)
|
)
|
||||||
|
|
@ -729,7 +733,7 @@ def _generate_rich_order_context_movement(
|
||||||
block.append(f"# Possible {mover_descr} unit movements & supports:")
|
block.append(f"# Possible {mover_descr} unit movements & supports:")
|
||||||
|
|
||||||
simple_moves = [o for o in orders if _is_simple_move(o)]
|
simple_moves = [o for o in orders if _is_simple_move(o)]
|
||||||
hold_orders = [o for o in orders if _is_hold_order(o)] # NEW
|
hold_orders = [o for o in orders if _is_hold_order(o)] # NEW
|
||||||
|
|
||||||
if not simple_moves and not hold_orders:
|
if not simple_moves and not hold_orders:
|
||||||
block.append(f"{ind}None")
|
block.append(f"{ind}None")
|
||||||
|
|
@ -737,22 +741,19 @@ def _generate_rich_order_context_movement(
|
||||||
# ---- Moves (same behaviour as before) ----
|
# ---- Moves (same behaviour as before) ----
|
||||||
for mv in simple_moves:
|
for mv in simple_moves:
|
||||||
mover, dest = _split_move(mv)
|
mover, dest = _split_move(mv)
|
||||||
occ = _dest_occupancy_desc(
|
occ = _dest_occupancy_desc(dest.split("/")[0][:3], game_map, board_state, power_name)
|
||||||
dest.split("/")[0][:3], game_map, board_state, power_name
|
|
||||||
)
|
|
||||||
block.append(f"{ind}{mv} {occ}")
|
block.append(f"{ind}{mv} {occ}")
|
||||||
|
|
||||||
for s in _all_support_examples(mover, dest, possible_orders_for_power):
|
for s in _all_support_examples(mover, dest, possible_orders_for_power):
|
||||||
block.append(f"{ind*2}Available Support: {s}")
|
block.append(f"{ind * 2}Available Support: {s}")
|
||||||
|
|
||||||
# ---- Holds (new) ----
|
# ---- Holds (new) ----
|
||||||
for hd in hold_orders:
|
for hd in hold_orders:
|
||||||
holder = hd.split(" H")[0] # e.g., 'F DEN'
|
holder = hd.split(" H")[0] # e.g., 'F DEN'
|
||||||
block.append(f"{ind}{hd}")
|
block.append(f"{ind}{hd}")
|
||||||
|
|
||||||
for s in _all_support_hold_examples(holder, possible_orders_for_power):
|
for s in _all_support_hold_examples(holder, possible_orders_for_power):
|
||||||
block.append(f"{ind*2}Available Support: {s}")
|
block.append(f"{ind * 2}Available Support: {s}")
|
||||||
|
|
||||||
|
|
||||||
block.append(f"</Territory {prov_short}>")
|
block.append(f"</Territory {prov_short}>")
|
||||||
blocks.append("\n".join(block))
|
blocks.append("\n".join(block))
|
||||||
|
|
@ -764,6 +765,7 @@ def _generate_rich_order_context_movement(
|
||||||
# Retreat-phase builder – echo orders verbatim, no tags
|
# Retreat-phase builder – echo orders verbatim, no tags
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _generate_rich_order_context_retreat(
|
def _generate_rich_order_context_retreat(
|
||||||
game: Any,
|
game: Any,
|
||||||
power_name: str,
|
power_name: str,
|
||||||
|
|
@ -784,11 +786,11 @@ def _generate_rich_order_context_retreat(
|
||||||
return "\n".join(lines) if lines else "(No dislodged units)"
|
return "\n".join(lines) if lines else "(No dislodged units)"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Adjustment-phase builder – summary line + orders, no WAIVEs, no tags
|
# Adjustment-phase builder – summary line + orders, no WAIVEs, no tags
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _generate_rich_order_context_adjustment(
|
def _generate_rich_order_context_adjustment(
|
||||||
game: Any,
|
game: Any,
|
||||||
power_name: str,
|
power_name: str,
|
||||||
|
|
@ -800,12 +802,12 @@ def _generate_rich_order_context_adjustment(
|
||||||
* No wrapper tags.
|
* No wrapper tags.
|
||||||
"""
|
"""
|
||||||
board_state = game.get_state()
|
board_state = game.get_state()
|
||||||
sc_owned = len(board_state.get("centers", {}).get(power_name, []))
|
sc_owned = len(board_state.get("centers", {}).get(power_name, []))
|
||||||
units_num = len(board_state.get("units", {}).get(power_name, []))
|
units_num = len(board_state.get("units", {}).get(power_name, []))
|
||||||
delta = sc_owned - units_num # +ve ⇒ builds, -ve ⇒ disbands
|
delta = sc_owned - units_num # +ve ⇒ builds, -ve ⇒ disbands
|
||||||
|
|
||||||
# ----- summary line ----------------------------------------------------
|
# ----- summary line ----------------------------------------------------
|
||||||
if delta > 0:
|
if delta > 0:
|
||||||
summary = f"Builds available: {delta}"
|
summary = f"Builds available: {delta}"
|
||||||
elif delta < 0:
|
elif delta < 0:
|
||||||
summary = f"Disbands required: {-delta}"
|
summary = f"Disbands required: {-delta}"
|
||||||
|
|
@ -824,13 +826,11 @@ def _generate_rich_order_context_adjustment(
|
||||||
return "\n".join(lines) if len(lines) > 1 else summary
|
return "\n".join(lines) if len(lines) > 1 else summary
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Phase-dispatch wrapper (public entry point)
|
# Phase-dispatch wrapper (public entry point)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def generate_rich_order_context(
|
def generate_rich_order_context(
|
||||||
game: Any,
|
game: Any,
|
||||||
power_name: str,
|
power_name: str,
|
||||||
|
|
@ -846,22 +846,14 @@ def generate_rich_order_context(
|
||||||
|
|
||||||
phase_type = game.current_short_phase[-1]
|
phase_type = game.current_short_phase[-1]
|
||||||
|
|
||||||
if phase_type == "M": # Movement
|
if phase_type == "M": # Movement
|
||||||
return _generate_rich_order_context_movement(
|
return _generate_rich_order_context_movement(game, power_name, possible_orders_for_power)
|
||||||
game, power_name, possible_orders_for_power
|
|
||||||
)
|
|
||||||
|
|
||||||
if phase_type == "R": # Retreat
|
if phase_type == "R": # Retreat
|
||||||
return _generate_rich_order_context_retreat(
|
return _generate_rich_order_context_retreat(game, power_name, possible_orders_for_power)
|
||||||
game, power_name, possible_orders_for_power
|
|
||||||
)
|
|
||||||
|
|
||||||
if phase_type == "A": # Adjustment (build / disband)
|
if phase_type == "A": # Adjustment (build / disband)
|
||||||
return _generate_rich_order_context_adjustment(
|
return _generate_rich_order_context_adjustment(game, power_name, possible_orders_for_power)
|
||||||
game, power_name, possible_orders_for_power
|
|
||||||
)
|
|
||||||
|
|
||||||
# Fallback – treat unknown formats as movement
|
# Fallback – treat unknown formats as movement
|
||||||
return _generate_rich_order_context_movement(
|
return _generate_rich_order_context_movement(game, power_name, possible_orders_for_power)
|
||||||
game, power_name, possible_orders_for_power
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,6 @@ from .possible_order_context import (
|
||||||
generate_rich_order_context,
|
generate_rich_order_context,
|
||||||
generate_rich_order_context_xml,
|
generate_rich_order_context_xml,
|
||||||
)
|
)
|
||||||
import os
|
|
||||||
from .game_history import GameHistory # Assuming GameHistory is correctly importable
|
from .game_history import GameHistory # Assuming GameHistory is correctly importable
|
||||||
|
|
||||||
# placeholder for diplomacy.Game to avoid circular or direct dependency if not needed for typehinting only
|
# placeholder for diplomacy.Game to avoid circular or direct dependency if not needed for typehinting only
|
||||||
|
|
@ -230,4 +229,3 @@ def construct_order_generation_prompt(
|
||||||
logger.debug(f"Final order generation prompt preview for {power_name}: {final_prompt[:500]}...")
|
logger.debug(f"Final order generation prompt preview for {power_name}: {final_prompt[:500]}...")
|
||||||
|
|
||||||
return final_prompt
|
return final_prompt
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,7 @@ def atomic_write_json(data: dict, filepath: str):
|
||||||
|
|
||||||
# Write to a temporary file in the same directory
|
# Write to a temporary file in the same directory
|
||||||
temp_filepath = f"{filepath}.tmp.{os.getpid()}"
|
temp_filepath = f"{filepath}.tmp.{os.getpid()}"
|
||||||
with open(temp_filepath, 'w', encoding='utf-8') as f:
|
with open(temp_filepath, "w", encoding="utf-8") as f:
|
||||||
json.dump(data, f, indent=4)
|
json.dump(data, f, indent=4)
|
||||||
|
|
||||||
# Atomically rename the temporary file to the final destination
|
# Atomically rename the temporary file to the final destination
|
||||||
|
|
@ -110,11 +110,7 @@ def get_special_models() -> Dict[str, str]:
|
||||||
# Use a different Gemini model for formatting
|
# Use a different Gemini model for formatting
|
||||||
export AI_DIPLOMACY_FORMATTER_MODEL="gemini-2.0-flash"
|
export AI_DIPLOMACY_FORMATTER_MODEL="gemini-2.0-flash"
|
||||||
"""
|
"""
|
||||||
return {
|
return {"phase_summary": config.AI_DIPLOMACY_NARRATIVE_MODEL, "formatter": config.AI_DIPLOMACY_FORMATTER_MODEL}
|
||||||
"phase_summary": config.AI_DIPLOMACY_NARRATIVE_MODEL,
|
|
||||||
"formatter": config.AI_DIPLOMACY_FORMATTER_MODEL
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def gather_possible_orders(game: Game, power_name: str) -> Dict[str, List[str]]:
|
def gather_possible_orders(game: Game, power_name: str) -> Dict[str, List[str]]:
|
||||||
|
|
@ -132,7 +128,7 @@ def gather_possible_orders(game: Game, power_name: str) -> Dict[str, List[str]]:
|
||||||
|
|
||||||
async def get_valid_orders(
|
async def get_valid_orders(
|
||||||
game: Game,
|
game: Game,
|
||||||
client, # BaseModelClient instance
|
client, # BaseModelClient instance
|
||||||
board_state,
|
board_state,
|
||||||
power_name: str,
|
power_name: str,
|
||||||
possible_orders: Dict[str, List[str]],
|
possible_orders: Dict[str, List[str]],
|
||||||
|
|
@ -170,20 +166,24 @@ async def get_valid_orders(
|
||||||
)
|
)
|
||||||
|
|
||||||
invalid_info: list[str] = []
|
invalid_info: list[str] = []
|
||||||
valid: list[str] = []
|
valid: list[str] = []
|
||||||
invalid: list[str] = []
|
invalid: list[str] = []
|
||||||
|
|
||||||
# ── 2. Type check ──────────────────────────────────────────
|
# ── 2. Type check ──────────────────────────────────────────
|
||||||
if not isinstance(raw_orders, list):
|
if not isinstance(raw_orders, list):
|
||||||
logger.warning("[%s] Orders received from LLM are not a list: %s. Using fallback.",
|
logger.warning("[%s] Orders received from LLM are not a list: %s. Using fallback.", power_name, raw_orders)
|
||||||
power_name, raw_orders)
|
|
||||||
model_error_stats[client.model_name]["order_decoding_errors"] += 1
|
model_error_stats[client.model_name]["order_decoding_errors"] += 1
|
||||||
return {"valid": client.fallback_orders(possible_orders), "invalid": []}
|
return {"valid": client.fallback_orders(possible_orders), "invalid": []}
|
||||||
|
|
||||||
# ── 3. Round-trip validation with engine ───────────────────
|
# ── 3. Round-trip validation with engine ───────────────────
|
||||||
CODE_TO_ENGINE = {
|
CODE_TO_ENGINE = {
|
||||||
"AUT": "AUSTRIA", "ENG": "ENGLAND", "FRA": "FRANCE",
|
"AUT": "AUSTRIA",
|
||||||
"GER": "GERMANY", "ITA": "ITALY", "RUS": "RUSSIA", "TUR": "TURKEY",
|
"ENG": "ENGLAND",
|
||||||
|
"FRA": "FRANCE",
|
||||||
|
"GER": "GERMANY",
|
||||||
|
"ITA": "ITALY",
|
||||||
|
"RUS": "RUSSIA",
|
||||||
|
"TUR": "TURKEY",
|
||||||
}
|
}
|
||||||
engine_power = power_name if power_name in game.powers else CODE_TO_ENGINE[power_name]
|
engine_power = power_name if power_name in game.powers else CODE_TO_ENGINE[power_name]
|
||||||
|
|
||||||
|
|
@ -202,16 +202,16 @@ async def get_valid_orders(
|
||||||
game.set_orders(engine_power, [upper])
|
game.set_orders(engine_power, [upper])
|
||||||
normed = game.get_orders(engine_power)
|
normed = game.get_orders(engine_power)
|
||||||
|
|
||||||
if normed: # accepted
|
if normed: # accepted
|
||||||
valid.append(normed[0])
|
valid.append(normed[0])
|
||||||
else: # rejected
|
else: # rejected
|
||||||
invalid.append(upper)
|
invalid.append(upper)
|
||||||
invalid_info.append(f"Order '{move}' is invalid for {power_name}")
|
invalid_info.append(f"Order '{move}' is invalid for {power_name}")
|
||||||
|
|
||||||
game.clear_orders(engine_power) # clean slate for main engine flow
|
game.clear_orders(engine_power) # clean slate for main engine flow
|
||||||
|
|
||||||
# ── 4. Legacy logging & stats updates ──────────────────────
|
# ── 4. Legacy logging & stats updates ──────────────────────
|
||||||
if invalid_info: # at least one bad move
|
if invalid_info: # at least one bad move
|
||||||
logger.debug("[%s] Invalid orders: %s", power_name, ", ".join(invalid_info))
|
logger.debug("[%s] Invalid orders: %s", power_name, ", ".join(invalid_info))
|
||||||
model_error_stats[client.model_name]["order_decoding_errors"] += 1
|
model_error_stats[client.model_name]["order_decoding_errors"] += 1
|
||||||
logger.debug("[%s] Some orders invalid, using fallback.", power_name)
|
logger.debug("[%s] Some orders invalid, using fallback.", power_name)
|
||||||
|
|
@ -226,8 +226,6 @@ async def get_valid_orders(
|
||||||
return {"valid": valid, "invalid": invalid}
|
return {"valid": valid, "invalid": invalid}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_and_compare_orders(
|
def normalize_and_compare_orders(
|
||||||
issued_orders: Dict[str, List[str]],
|
issued_orders: Dict[str, List[str]],
|
||||||
accepted_orders_dict: Dict[str, List[str]],
|
accepted_orders_dict: Dict[str, List[str]],
|
||||||
|
|
@ -325,18 +323,18 @@ def load_prompt(filename: str, prompts_dir: Optional[str] = None) -> str:
|
||||||
3. Elif *prompts_dir* is provided → join prompts_dir + filename.
|
3. Elif *prompts_dir* is provided → join prompts_dir + filename.
|
||||||
4. Otherwise → join the package’s default prompts dir.
|
4. Otherwise → join the package’s default prompts dir.
|
||||||
"""
|
"""
|
||||||
if os.path.isabs(filename): # rule 1
|
if os.path.isabs(filename): # rule 1
|
||||||
prompt_path = filename
|
prompt_path = filename
|
||||||
elif os.path.dirname(filename): # rule 2 (has slash)
|
elif os.path.dirname(filename): # rule 2 (has slash)
|
||||||
# If it's a relative path with directory, join with prompts_dir if provided
|
# If it's a relative path with directory, join with prompts_dir if provided
|
||||||
if prompts_dir:
|
if prompts_dir:
|
||||||
prompt_path = os.path.join(prompts_dir, filename)
|
prompt_path = os.path.join(prompts_dir, filename)
|
||||||
else:
|
else:
|
||||||
default_dir = os.path.join(os.path.dirname(__file__), "prompts")
|
default_dir = os.path.join(os.path.dirname(__file__), "prompts")
|
||||||
prompt_path = os.path.join(default_dir, filename)
|
prompt_path = os.path.join(default_dir, filename)
|
||||||
elif prompts_dir: # rule 3
|
elif prompts_dir: # rule 3
|
||||||
prompt_path = os.path.join(prompts_dir, filename)
|
prompt_path = os.path.join(prompts_dir, filename)
|
||||||
else: # rule 4
|
else: # rule 4
|
||||||
default_dir = os.path.join(os.path.dirname(__file__), "prompts")
|
default_dir = os.path.join(os.path.dirname(__file__), "prompts")
|
||||||
prompt_path = os.path.join(default_dir, filename)
|
prompt_path = os.path.join(default_dir, filename)
|
||||||
|
|
||||||
|
|
@ -348,16 +346,14 @@ def load_prompt(filename: str, prompts_dir: Optional[str] = None) -> str:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# == New LLM Response Logging Function ==
|
# == New LLM Response Logging Function ==
|
||||||
def log_llm_response(
|
def log_llm_response(
|
||||||
log_file_path: str,
|
log_file_path: str,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
power_name: Optional[str], # Optional for non-power-specific calls like summary
|
power_name: Optional[str], # Optional for non-power-specific calls like summary
|
||||||
phase: str,
|
phase: str,
|
||||||
response_type: str,
|
response_type: str,
|
||||||
raw_input_prompt: str, # Added new parameter for the raw input
|
raw_input_prompt: str, # Added new parameter for the raw input
|
||||||
raw_response: str,
|
raw_response: str,
|
||||||
success: str, # Changed from bool to str
|
success: str, # Changed from bool to str
|
||||||
):
|
):
|
||||||
|
|
@ -365,8 +361,8 @@ def log_llm_response(
|
||||||
try:
|
try:
|
||||||
# Ensure the directory exists
|
# Ensure the directory exists
|
||||||
log_dir = os.path.dirname(log_file_path)
|
log_dir = os.path.dirname(log_file_path)
|
||||||
if log_dir: # Ensure log_dir is not empty (e.g., if path is just a filename)
|
if log_dir: # Ensure log_dir is not empty (e.g., if path is just a filename)
|
||||||
os.makedirs(log_dir, exist_ok=True)
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
|
|
||||||
# Check if file exists and has content to determine if we need headers
|
# Check if file exists and has content to determine if we need headers
|
||||||
file_exists = os.path.isfile(log_file_path) and os.path.getsize(log_file_path) > 0
|
file_exists = os.path.isfile(log_file_path) and os.path.getsize(log_file_path) > 0
|
||||||
|
|
@ -374,34 +370,38 @@ def log_llm_response(
|
||||||
with open(log_file_path, "a", newline="", encoding="utf-8") as csvfile:
|
with open(log_file_path, "a", newline="", encoding="utf-8") as csvfile:
|
||||||
# Added "raw_input" to fieldnames
|
# Added "raw_input" to fieldnames
|
||||||
fieldnames = ["model", "power", "phase", "response_type", "raw_input", "raw_response", "success"]
|
fieldnames = ["model", "power", "phase", "response_type", "raw_input", "raw_response", "success"]
|
||||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames,
|
writer = csv.DictWriter(
|
||||||
quoting=csv.QUOTE_ALL, # Quote all fields to handle commas and newlines
|
csvfile,
|
||||||
escapechar='\\') # Use backslash for escaping
|
fieldnames=fieldnames,
|
||||||
|
quoting=csv.QUOTE_ALL, # Quote all fields to handle commas and newlines
|
||||||
|
escapechar="\\",
|
||||||
|
) # Use backslash for escaping
|
||||||
|
|
||||||
if not file_exists:
|
if not file_exists:
|
||||||
writer.writeheader() # Write header only if file is new
|
writer.writeheader() # Write header only if file is new
|
||||||
|
|
||||||
writer.writerow({
|
writer.writerow(
|
||||||
"model": model_name,
|
{
|
||||||
"power": power_name if power_name else "game", # Use 'game' if no specific power
|
"model": model_name,
|
||||||
"phase": phase,
|
"power": power_name if power_name else "game", # Use 'game' if no specific power
|
||||||
"response_type": response_type,
|
"phase": phase,
|
||||||
"raw_input": raw_input_prompt, # Added raw_input to the row
|
"response_type": response_type,
|
||||||
"raw_response": raw_response,
|
"raw_input": raw_input_prompt, # Added raw_input to the row
|
||||||
"success": success,
|
"raw_response": raw_response,
|
||||||
})
|
"success": success,
|
||||||
|
}
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to log LLM response to {log_file_path}: {e}", exc_info=True)
|
logger.error(f"Failed to log LLM response to {log_file_path}: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def run_llm_and_log(
|
async def run_llm_and_log(
|
||||||
client: "BaseModelClient",
|
client: "BaseModelClient",
|
||||||
prompt: str,
|
prompt: str,
|
||||||
log_file_path: str, # Kept for context, but not used for logging here
|
log_file_path: str, # Kept for context, but not used for logging here
|
||||||
power_name: Optional[str], # Kept for context, but not used for logging here
|
power_name: Optional[str], # Kept for context, but not used for logging here
|
||||||
phase: str, # Kept for context, but not used for logging here
|
phase: str, # Kept for context, but not used for logging here
|
||||||
response_type: str, # Kept for context, but not used for logging here
|
response_type: str, # Kept for context, but not used for logging here
|
||||||
temperature: float = 0.0,
|
temperature: float = 0.0,
|
||||||
*,
|
*,
|
||||||
attempts: int = 5,
|
attempts: int = 5,
|
||||||
|
|
@ -428,35 +428,26 @@ async def run_llm_and_log(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if attempt == attempts - 1:
|
if attempt == attempts - 1:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"API Error during LLM call for {client.model_name}/{power_name}/{response_type} "
|
f"API Error during LLM call for {client.model_name}/{power_name}/{response_type} in phase {phase}: {e}",
|
||||||
f"in phase {phase}: {e}",
|
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
# Back-off before the next attempt (unless this was the last)
|
# Back-off before the next attempt (unless this was the last)
|
||||||
delay = backoff_base * (backoff_factor ** attempt) + random.uniform(0, jitter)
|
delay = backoff_base * (backoff_factor**attempt) + random.uniform(0, jitter)
|
||||||
await asyncio.sleep(delay)
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
return raw_response
|
return raw_response
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# This generates a few lines of random alphanum chars to inject into the
|
# This generates a few lines of random alphanum chars to inject into the
|
||||||
# system prompt. This lets us use temp=0 while still getting variation
|
# system prompt. This lets us use temp=0 while still getting variation
|
||||||
# between trials.
|
# between trials.
|
||||||
# Temp=0 is important for better performance on deciding moves, and to
|
# Temp=0 is important for better performance on deciding moves, and to
|
||||||
# ensure valid json outputs.
|
# ensure valid json outputs.
|
||||||
def generate_random_seed(n_lines: int = 5, n_chars_per_line: int = 80):
|
def generate_random_seed(n_lines: int = 5, n_chars_per_line: int = 80):
|
||||||
# Generate x lines of y random alphanumeric characters
|
# Generate x lines of y random alphanumeric characters
|
||||||
seed_lines = [
|
seed_lines = ["".join(random.choices(string.ascii_letters + string.digits, k=n_chars_per_line)) for _ in range(n_lines)]
|
||||||
''.join(random.choices(string.ascii_letters + string.digits, k=n_chars_per_line))
|
random_seed_block = "<RANDOM SEED PLEASE IGNORE>\n" + "\n".join(seed_lines) + "\n</RANDOM SEED>"
|
||||||
for _ in range(n_lines)
|
return random_seed_block
|
||||||
]
|
|
||||||
random_seed_block = (
|
|
||||||
"<RANDOM SEED PLEASE IGNORE>\n" +
|
|
||||||
"\n".join(seed_lines) +
|
|
||||||
"\n</RANDOM SEED>"
|
|
||||||
)
|
|
||||||
return random_seed_block
|
|
||||||
|
|
||||||
|
|
||||||
def get_prompt_path(prompt_name: str) -> str:
|
def get_prompt_path(prompt_name: str) -> str:
|
||||||
|
|
@ -473,35 +464,36 @@ def get_prompt_path(prompt_name: str) -> str:
|
||||||
else:
|
else:
|
||||||
return prompt_name
|
return prompt_name
|
||||||
|
|
||||||
|
|
||||||
def normalize_recipient_name(recipient: str) -> str:
|
def normalize_recipient_name(recipient: str) -> str:
|
||||||
"""Normalize recipient names to handle LLM typos and abbreviations."""
|
"""Normalize recipient names to handle LLM typos and abbreviations."""
|
||||||
if not recipient:
|
if not recipient:
|
||||||
return recipient
|
return recipient
|
||||||
|
|
||||||
recipient = recipient.upper().strip()
|
recipient = recipient.upper().strip()
|
||||||
|
|
||||||
# Handle common LLM typos and abbreviations found in data
|
# Handle common LLM typos and abbreviations found in data
|
||||||
name_mapping = {
|
name_mapping = {
|
||||||
'EGMANY': 'GERMANY',
|
"EGMANY": "GERMANY",
|
||||||
'GERMAN': 'GERMANY',
|
"GERMAN": "GERMANY",
|
||||||
'UK': 'ENGLAND',
|
"UK": "ENGLAND",
|
||||||
'BRIT': 'ENGLAND',
|
"BRIT": "ENGLAND",
|
||||||
'ENGLAND': 'ENGLAND', # Keep as-is
|
"ENGLAND": "ENGLAND", # Keep as-is
|
||||||
'FRANCE': 'FRANCE', # Keep as-is
|
"FRANCE": "FRANCE", # Keep as-is
|
||||||
'GERMANY': 'GERMANY', # Keep as-is
|
"GERMANY": "GERMANY", # Keep as-is
|
||||||
'ITALY': 'ITALY', # Keep as-is
|
"ITALY": "ITALY", # Keep as-is
|
||||||
'AUSTRIA': 'AUSTRIA', # Keep as-is
|
"AUSTRIA": "AUSTRIA", # Keep as-is
|
||||||
'RUSSIA': 'RUSSIA', # Keep as-is
|
"RUSSIA": "RUSSIA", # Keep as-is
|
||||||
'TURKEY': 'TURKEY', # Keep as-is
|
"TURKEY": "TURKEY", # Keep as-is
|
||||||
'Germany': 'GERMANY',
|
"Germany": "GERMANY",
|
||||||
'England': 'ENGLAND',
|
"England": "ENGLAND",
|
||||||
'France': 'FRANCE',
|
"France": "FRANCE",
|
||||||
'Italy': 'ITALY',
|
"Italy": "ITALY",
|
||||||
'Russia': 'RUSSIA',
|
"Russia": "RUSSIA",
|
||||||
'Austria': 'AUSTRIA',
|
"Austria": "AUSTRIA",
|
||||||
'Turkey': 'TURKEY',
|
"Turkey": "TURKEY",
|
||||||
}
|
}
|
||||||
|
|
||||||
normalized = name_mapping.get(recipient, recipient)
|
normalized = name_mapping.get(recipient, recipient)
|
||||||
|
|
||||||
return normalized
|
return normalized
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue