Ruff format and check

This commit is contained in:
Tyler Marques 2025-07-03 14:27:50 -07:00
parent aadedd6512
commit ad817695ef
No known key found for this signature in database
GPG key ID: CB99EDCF41D3016F
13 changed files with 738 additions and 973 deletions

View file

@ -10,7 +10,7 @@ import ast
from ..config import config
# Assuming BaseModelClient is importable from clients.py in the same directory
from .clients import BaseModelClient, load_model_client
from .clients import BaseModelClient
# Import load_prompt and the new logging wrapper from utils
from .utils import load_prompt, run_llm_and_log, log_llm_response, get_prompt_path
@ -336,7 +336,7 @@ class DiplomacyAgent:
)
# If even the last resort doesn't give a dict, return empty dict
return {}
except Exception as e:
except Exception:
logger.error(f"[{self.power_name}] All JSON extraction attempts failed. Original text: {original_text[:500]}...")
return {}
@ -965,7 +965,7 @@ class DiplomacyAgent:
other_powers = [p for p in game.powers if p != power_name]
# Create a readable board state string from the board_state dict
board_state_str = f"Board State:\n"
board_state_str = "Board State:\n"
for p_name, power_data in board_state.get("powers", {}).items():
# Get units and centers from the board state
units = power_data.get("units", [])

View file

@ -1,12 +1,11 @@
import os
import json
from json import JSONDecodeError
import re
import logging
import ast # For literal_eval in JSON fallback parsing
import aiohttp # For direct HTTP requests to Responses API
from typing import List, Dict, Optional, Any, Tuple, NamedTuple
from typing import List, Dict, Optional, Tuple, NamedTuple
from dotenv import load_dotenv
# Use Async versions of clients
@ -21,9 +20,9 @@ from together import AsyncTogether
from together.error import APIError as TogetherAPIError # For specific error handling
from ..config import config
from diplomacy.engine.message import GLOBAL
from .game_history import GameHistory
from .utils import load_prompt, run_llm_and_log, log_llm_response, generate_random_seed, get_prompt_path
# Import DiplomacyAgent for type hinting if needed, but avoid circular import if possible
from .prompt_constructor import construct_order_generation_prompt, build_context_prompt
# Moved formatter imports to avoid circular import - imported locally where needed
@ -36,6 +35,7 @@ logger.setLevel(logging.DEBUG) # Keep debug for now during async changes
load_dotenv()
##############################################################################
# 1) Base Interface
##############################################################################
@ -118,24 +118,19 @@ class BaseModelClient:
log_file_path=log_file_path,
power_name=power_name,
phase=phase,
response_type='order', # Context for run_llm_and_log's own error logging
temperature=0
)
logger.debug(
f"[{self.model_name}] Raw LLM response for {power_name} orders:\n{raw_response}"
response_type="order", # Context for run_llm_and_log's own error logging
temperature=0,
)
logger.debug(f"[{self.model_name}] Raw LLM response for {power_name} orders:\n{raw_response}")
# Conditionally format the response based on USE_UNFORMATTED_PROMPTS
if config.USE_UNFORMATTED_PROMPTS:
# Local import to avoid circular dependency
from .formatter import format_with_gemini_flash, FORMAT_ORDERS
# Format the natural language response into structured format
formatted_response = await format_with_gemini_flash(
raw_response,
FORMAT_ORDERS,
power_name=power_name,
phase=phase,
log_file_path=log_file_path
raw_response, FORMAT_ORDERS, power_name=power_name, phase=phase, log_file_path=log_file_path
)
else:
# Use the raw response directly (already formatted)
@ -145,9 +140,7 @@ class BaseModelClient:
move_list = self._extract_moves(formatted_response, power_name)
if not move_list:
logger.warning(
f"[{self.model_name}] Could not extract moves for {power_name}. Using fallback."
)
logger.warning(f"[{self.model_name}] Could not extract moves for {power_name}. Using fallback.")
if model_error_stats is not None and self.model_name in model_error_stats:
model_error_stats[self.model_name].setdefault("order_decoding_errors", 0)
model_error_stats[self.model_name]["order_decoding_errors"] += 1
@ -195,7 +188,7 @@ class BaseModelClient:
response_type="order_generation", # Specific type for CSV logging
raw_input_prompt=prompt, # Renamed from 'prompt' to match log_llm_response arg
raw_response=raw_response,
success=success_status
success=success_status,
# token_usage and cost can be added later if available and if log_llm_response supports them
)
return parsed_orders_for_return
@ -216,9 +209,7 @@ class BaseModelClient:
if not matches:
# Some LLMs might not put the colon or might have triple backtick fences.
logger.debug(
f"[{self.model_name}] Regex parse #1 failed for {power_name}. Trying alternative patterns."
)
logger.debug(f"[{self.model_name}] Regex parse #1 failed for {power_name}. Trying alternative patterns.")
# 1b) Check for inline JSON after "PARSABLE OUTPUT"
pattern_alt = r"PARSABLE OUTPUT\s*\{(.*?)\}\s*$"
@ -226,47 +217,35 @@ class BaseModelClient:
if not matches:
# 1c) Check for **PARSABLE OUTPUT:** pattern (with asterisks)
logger.debug(
f"[{self.model_name}] Regex parse #2 failed for {power_name}. Trying asterisk-wrapped pattern."
)
logger.debug(f"[{self.model_name}] Regex parse #2 failed for {power_name}. Trying asterisk-wrapped pattern.")
pattern_asterisk = r"\*\*PARSABLE OUTPUT:\*\*\s*(\{[\s\S]*?\})"
matches = re.search(pattern_asterisk, raw_response, re.DOTALL)
if not matches:
logger.debug(
f"[{self.model_name}] Regex parse #3 failed for {power_name}. Trying triple-backtick code fences."
)
logger.debug(f"[{self.model_name}] Regex parse #3 failed for {power_name}. Trying triple-backtick code fences.")
# 2) If still no match, check for triple-backtick code fences containing JSON
if not matches:
code_fence_pattern = r"```json\n(.*?)\n```"
matches = re.search(code_fence_pattern, raw_response, re.DOTALL)
if matches:
logger.debug(
f"[{self.model_name}] Found triple-backtick JSON block for {power_name}."
)
logger.debug(f"[{self.model_name}] Found triple-backtick JSON block for {power_name}.")
# 2b) Also try plain ``` code fences without json marker
if not matches:
code_fence_plain = r"```\n(.*?)\n```"
matches = re.search(code_fence_plain, raw_response, re.DOTALL)
if matches:
logger.debug(
f"[{self.model_name}] Found plain triple-backtick block for {power_name}."
)
logger.debug(f"[{self.model_name}] Found plain triple-backtick block for {power_name}.")
# 2c) Try to find bare JSON object anywhere in the response
if not matches:
logger.debug(
f"[{self.model_name}] No explicit markers found for {power_name}. Looking for bare JSON."
)
logger.debug(f"[{self.model_name}] No explicit markers found for {power_name}. Looking for bare JSON.")
# Look for a JSON object that contains "orders" key
bare_json_pattern = r'(\{[^{}]*"orders"\s*:\s*\[[^\]]*\][^{}]*\})'
matches = re.search(bare_json_pattern, raw_response, re.DOTALL)
if matches:
logger.debug(
f"[{self.model_name}] Found bare JSON object with 'orders' key for {power_name}."
)
logger.debug(f"[{self.model_name}] Found bare JSON object with 'orders' key for {power_name}.")
# 3) Attempt to parse JSON if we found anything
json_text = None
@ -283,9 +262,7 @@ class BaseModelClient:
json_text = json_text.strip()
if not json_text:
logger.debug(
f"[{self.model_name}] No JSON text found in LLM response for {power_name}."
)
logger.debug(f"[{self.model_name}] No JSON text found in LLM response for {power_name}.")
return None
# 3a) Try JSON loading
@ -293,14 +270,12 @@ class BaseModelClient:
data = json.loads(json_text)
return data.get("orders", None)
except json.JSONDecodeError as e:
logger.warning(
f"[{self.model_name}] JSON decode failed for {power_name}: {e}. Trying to fix common issues."
)
logger.warning(f"[{self.model_name}] JSON decode failed for {power_name}: {e}. Trying to fix common issues.")
# Try to fix common JSON issues
try:
# Remove trailing commas
fixed_json = re.sub(r',\s*([\}\]])', r'\1', json_text)
fixed_json = re.sub(r",\s*([\}\]])", r"\1", json_text)
# Fix single quotes to double quotes
fixed_json = fixed_json.replace("'", '"')
# Try parsing again
@ -308,14 +283,12 @@ class BaseModelClient:
logger.info(f"[{self.model_name}] Successfully parsed JSON after fixes for {power_name}")
return data.get("orders", None)
except json.JSONDecodeError:
logger.warning(
f"[{self.model_name}] JSON decode still failed after fixes for {power_name}. Trying to remove inline comments."
)
logger.warning(f"[{self.model_name}] JSON decode still failed after fixes for {power_name}. Trying to remove inline comments.")
# Try to remove inline comments (// style)
try:
# Remove // comments from each line
lines = json_text.split('\n')
lines = json_text.split("\n")
cleaned_lines = []
for line in lines:
# Find // that's not inside quotes
@ -326,12 +299,12 @@ class BaseModelClient:
if escape_next:
escape_next = False
continue
if char == '\\':
if char == "\\":
escape_next = True
continue
if char == '"' and not escape_next:
in_quotes = not in_quotes
if not in_quotes and line[i:i+2] == '//':
if not in_quotes and line[i : i + 2] == "//":
comment_pos = i
break
@ -342,17 +315,15 @@ class BaseModelClient:
cleaned_line = line
cleaned_lines.append(cleaned_line)
comment_free_json = '\n'.join(cleaned_lines)
comment_free_json = "\n".join(cleaned_lines)
# Also remove trailing commas after comment removal
comment_free_json = re.sub(r',\s*([\}\]])', r'\1', comment_free_json)
comment_free_json = re.sub(r",\s*([\}\]])", r"\1", comment_free_json)
data = json.loads(comment_free_json)
logger.info(f"[{self.model_name}] Successfully parsed JSON after removing inline comments for {power_name}")
return data.get("orders", None)
except json.JSONDecodeError:
logger.warning(
f"[{self.model_name}] JSON decode still failed after removing comments for {power_name}. Trying bracket fallback."
)
logger.warning(f"[{self.model_name}] JSON decode still failed after removing comments for {power_name}. Trying bracket fallback.")
# 3b) Attempt bracket fallback: we look for the substring after "orders"
# E.g. "orders: ['A BUD H']" and parse it. This is risky but can help with minor JSON format errors.
@ -366,16 +337,12 @@ class BaseModelClient:
if isinstance(moves, list):
return moves
except Exception as e2:
logger.warning(
f"[{self.model_name}] Bracket fallback parse also failed for {power_name}: {e2}"
)
logger.warning(f"[{self.model_name}] Bracket fallback parse also failed for {power_name}: {e2}")
# If all attempts failed
return None
def _validate_orders(
self, moves: List[str], possible_orders: Dict[str, List[str]]
) -> Tuple[List[str], List[str]]: # MODIFIED RETURN TYPE
def _validate_orders(self, moves: List[str], possible_orders: Dict[str, List[str]]) -> Tuple[List[str], List[str]]: # MODIFIED RETURN TYPE
"""
Filter out invalid moves, fill missing with HOLD, else fallback.
Returns a tuple: (validated_moves, invalid_moves_found)
@ -405,15 +372,15 @@ class BaseModelClient:
for loc, orders_list in possible_orders.items():
if loc not in used_locs and orders_list:
hold_candidates = [o for o in orders_list if o.endswith("H")]
validated.append(
hold_candidates[0] if hold_candidates else orders_list[0]
)
validated.append(hold_candidates[0] if hold_candidates else orders_list[0])
if not validated and not invalid_moves_found: # Only if LLM provided no valid moves and no invalid moves (e.g. empty list from LLM)
logger.warning(f"[{self.model_name}] No valid LLM moves provided and no invalid ones to report. Using fallback.")
return self.fallback_orders(possible_orders), []
elif not validated and invalid_moves_found: # All LLM moves were invalid
logger.warning(f"[{self.model_name}] All LLM moves invalid ({len(invalid_moves_found)} found), using fallback. Invalid: {invalid_moves_found}")
logger.warning(
f"[{self.model_name}] All LLM moves invalid ({len(invalid_moves_found)} found), using fallback. Invalid: {invalid_moves_found}"
)
# We return empty list for validated, but the invalid_moves_found list is populated
return self.fallback_orders(possible_orders), invalid_moves_found
@ -444,7 +411,6 @@ class BaseModelClient:
agent_relationships: Optional[Dict[str, str]] = None,
agent_private_diary_str: Optional[str] = None, # Added
) -> str:
instructions = load_prompt("planning_instructions.txt", prompts_dir=self.prompts_dir)
context = self.build_context_prompt(
@ -508,7 +474,15 @@ class BaseModelClient:
unanswered_messages += "\nNo urgent messages requiring direct responses.\n"
final_prompt = context + unanswered_messages + "\n\n" + instructions
final_prompt = final_prompt.replace('AUSTRIA', 'Austria').replace('ENGLAND', "England").replace('FRANCE', 'France').replace('GERMANY', 'Germany').replace('ITALY', "Italy").replace('RUSSIA', 'Russia').replace('TURKEY', 'Turkey')
final_prompt = (
final_prompt.replace("AUSTRIA", "Austria")
.replace("ENGLAND", "England")
.replace("FRANCE", "France")
.replace("GERMANY", "Germany")
.replace("ITALY", "Italy")
.replace("RUSSIA", "Russia")
.replace("TURKEY", "Turkey")
)
return final_prompt
async def get_planning_reply( # Renamed from get_plan to avoid conflict with get_plan in agent.py
@ -524,7 +498,6 @@ class BaseModelClient:
agent_relationships: Optional[Dict[str, str]] = None,
agent_private_diary_str: Optional[str] = None, # Added
) -> str:
prompt = self.build_planning_prompt(
game,
board_state,
@ -545,7 +518,7 @@ class BaseModelClient:
log_file_path=log_file_path,
power_name=power_name,
phase=game_phase, # Use game_phase for logging
response_type='plan_reply', # Changed from 'plan' to avoid confusion
response_type="plan_reply", # Changed from 'plan' to avoid confusion
)
logger.debug(f"[{self.model_name}] Raw LLM response for {power_name} planning reply:\n{raw_response}")
return raw_response
@ -592,7 +565,7 @@ class BaseModelClient:
log_file_path=log_file_path,
power_name=power_name,
phase=game_phase,
response_type='negotiation', # For run_llm_and_log's internal context
response_type="negotiation", # For run_llm_and_log's internal context
)
logger.debug(f"[{self.model_name}] Raw LLM response for {power_name}:\n{raw_response}")
@ -600,13 +573,10 @@ class BaseModelClient:
if config.USE_UNFORMATTED_PROMPTS:
# Local import to avoid circular dependency
from .formatter import format_with_gemini_flash, FORMAT_CONVERSATION
# Format the natural language response into structured JSON
formatted_response = await format_with_gemini_flash(
raw_response,
FORMAT_CONVERSATION,
power_name=power_name,
phase=game_phase,
log_file_path=log_file_path
raw_response, FORMAT_CONVERSATION, power_name=power_name, phase=game_phase, log_file_path=log_file_path
)
else:
# Use the raw response directly (already formatted)
@ -632,10 +602,10 @@ class BaseModelClient:
# Original parsing logic as fallback
if not parsed_messages:
# Attempt to find blocks enclosed in {{...}}
double_brace_blocks = re.findall(r'\{\{(.*?)\}\}', raw_response, re.DOTALL)
double_brace_blocks = re.findall(r"\{\{(.*?)\}\}", raw_response, re.DOTALL)
if double_brace_blocks:
# If {{...}} blocks are found, assume each is a self-contained JSON object
json_blocks.extend(['{' + block.strip() + '}' for block in double_brace_blocks])
json_blocks.extend(["{" + block.strip() + "}" for block in double_brace_blocks])
else:
# If no {{...}} blocks, look for ```json ... ``` markdown blocks
code_block_match = re.search(r"```json\n(.*?)\n```", raw_response, re.DOTALL)
@ -650,10 +620,10 @@ class BaseModelClient:
json_blocks = [json.dumps(data)]
except json.JSONDecodeError:
# If parsing the whole block fails, fall back to regex for individual objects
json_blocks = re.findall(r'\{.*?\}', potential_json_array_or_objects, re.DOTALL)
json_blocks = re.findall(r"\{.*?\}", potential_json_array_or_objects, re.DOTALL)
else:
# If no markdown block, fall back to regex for any JSON object in the response
json_blocks = re.findall(r'\{.*?\}', raw_response, re.DOTALL)
json_blocks = re.findall(r"\{.*?\}", raw_response, re.DOTALL)
# Process json_blocks if we have them from fallback parsing
if not parsed_messages and json_blocks:
@ -661,7 +631,7 @@ class BaseModelClient:
try:
cleaned_block = block.strip()
# Attempt to fix common JSON issues like trailing commas before parsing
cleaned_block = re.sub(r',\s*([\}\]])', r'\1', cleaned_block)
cleaned_block = re.sub(r",\s*([\}\]])", r"\1", cleaned_block)
parsed_message = json.loads(cleaned_block)
parsed_messages.append(parsed_message)
except json.JSONDecodeError as e:
@ -710,7 +680,7 @@ class BaseModelClient:
response_type="negotiation_message",
raw_input_prompt=raw_input_prompt,
raw_response=raw_response,
success=success_status
success=success_status,
)
return messages_to_return
@ -770,7 +740,7 @@ class BaseModelClient:
log_file_path=log_file_path,
power_name=power_name,
phase=game.current_short_phase,
response_type='plan_generation', # More specific type for run_llm_and_log context
response_type="plan_generation", # More specific type for run_llm_and_log context
)
logger.debug(f"[{self.model_name}] Raw LLM response for {power_name} plan generation:\n{raw_plan_response}")
# No parsing needed for the plan, return the raw string
@ -790,7 +760,7 @@ class BaseModelClient:
response_type="plan_generation", # Specific type for CSV logging
raw_input_prompt=full_prompt, # Renamed from 'full_prompt' to match log_llm_response arg
raw_response=raw_plan_response,
success=success_status
success=success_status,
# token_usage and cost can be added later
)
return plan_to_return
@ -813,11 +783,7 @@ class OpenAIClient(BaseModelClient):
):
super().__init__(model_name, prompts_dir=prompts_dir)
self.base_url = (
base_url
or os.environ.get("OPENAI_BASE_URL")
or "https://api.openai.com/v1"
)
self.base_url = base_url or os.environ.get("OPENAI_BASE_URL") or "https://api.openai.com/v1"
self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
if not self.api_key:
@ -832,11 +798,7 @@ class OpenAIClient(BaseModelClient):
inject_random_seed: bool = True,
) -> str:
try:
system_prompt_content = (
f"{generate_random_seed()}\n\n{self.system_prompt}"
if inject_random_seed
else self.system_prompt
)
system_prompt_content = f"{generate_random_seed()}\n\n{self.system_prompt}" if inject_random_seed else self.system_prompt
prompt_with_cta = f"{prompt}\n\nPROVIDE YOUR RESPONSE BELOW:"
response = await self.client.chat.completions.create(
@ -888,20 +850,14 @@ class ClaudeClient(BaseModelClient):
temperature=temperature,
)
if not response.content:
logger.warning(
f"[{self.model_name}] Empty content in Claude generate_response. Returning empty."
)
logger.warning(f"[{self.model_name}] Empty content in Claude generate_response. Returning empty.")
return ""
return response.content[0].text.strip() if response.content else ""
except json.JSONDecodeError as json_err:
logger.error(
f"[{self.model_name}] JSON decoding failed in generate_response: {json_err}"
)
logger.error(f"[{self.model_name}] JSON decoding failed in generate_response: {json_err}")
return ""
except Exception as e:
logger.error(
f"[{self.model_name}] Unexpected error in generate_response: {e}"
)
logger.error(f"[{self.model_name}] Unexpected error in generate_response: {e}")
return ""
@ -929,19 +885,14 @@ class GeminiClient(BaseModelClient):
full_prompt = system_prompt_content + prompt + "\n\nPROVIDE YOUR RESPONSE BELOW:"
try:
generation_config = genai.types.GenerationConfig(
temperature=temperature,
max_output_tokens=self.max_tokens
)
generation_config = genai.types.GenerationConfig(temperature=temperature, max_output_tokens=self.max_tokens)
response = await self.client.generate_content_async(
contents=full_prompt,
generation_config=generation_config,
)
if not response or not response.text:
logger.warning(
f"[{self.model_name}] Empty Gemini generate_response. Returning empty."
)
logger.warning(f"[{self.model_name}] Empty Gemini generate_response. Returning empty.")
return ""
return response.text.strip()
except Exception as e:
@ -957,10 +908,7 @@ class DeepSeekClient(BaseModelClient):
def __init__(self, model_name: str, prompts_dir: Optional[str] = None):
super().__init__(model_name, prompts_dir=prompts_dir)
self.api_key = os.environ.get("DEEPSEEK_API_KEY")
self.client = AsyncDeepSeekOpenAI(
api_key=self.api_key,
base_url="https://api.deepseek.com/"
)
self.client = AsyncDeepSeekOpenAI(api_key=self.api_key, base_url="https://api.deepseek.com/")
async def generate_response(self, prompt: str, temperature: float = 0.0, inject_random_seed: bool = True) -> str:
try:
@ -986,9 +934,7 @@ class DeepSeekClient(BaseModelClient):
logger.debug(f"[{self.model_name}] Raw DeepSeek response:\n{response}")
if not response or not response.choices:
logger.warning(
f"[{self.model_name}] No valid response in generate_response."
)
logger.warning(f"[{self.model_name}] No valid response in generate_response.")
return ""
content = response.choices[0].message.content.strip()
@ -998,9 +944,7 @@ class DeepSeekClient(BaseModelClient):
return content
except Exception as e:
logger.error(
f"[{self.model_name}] Unexpected error in generate_response: {e}"
)
logger.error(f"[{self.model_name}] Unexpected error in generate_response: {e}")
return ""
@ -1037,19 +981,14 @@ class OpenAIResponsesClient(BaseModelClient):
"max_tokens": self.max_tokens,
}
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}"
}
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
# Make the API call using aiohttp
async with aiohttp.ClientSession() as session:
async with session.post(self.base_url, json=payload, headers=headers) as response:
if response.status != 200:
error_text = await response.text()
logger.error(
f"[{self.model_name}] API error (status {response.status}): {error_text}"
)
logger.error(f"[{self.model_name}] API error (status {response.status}): {error_text}")
return ""
response_data = await response.json()
@ -1059,24 +998,18 @@ class OpenAIResponsesClient(BaseModelClient):
try:
outputs = response_data.get("output", [])
if len(outputs) < 2:
logger.warning(
f"[{self.model_name}] Unexpected output structure. Full response: {response_data}"
)
logger.warning(f"[{self.model_name}] Unexpected output structure. Full response: {response_data}")
return ""
# The message is typically in the second output item
message_output = outputs[1]
if message_output.get("type") != "message":
logger.warning(
f"[{self.model_name}] Expected message type in output[1]. Got: {message_output.get('type')}"
)
logger.warning(f"[{self.model_name}] Expected message type in output[1]. Got: {message_output.get('type')}")
return ""
content_list = message_output.get("content", [])
if not content_list:
logger.warning(
f"[{self.model_name}] Empty content list in message output"
)
logger.warning(f"[{self.model_name}] Empty content list in message output")
return ""
# Look for the content item with type 'output_text'
@ -1087,28 +1020,20 @@ class OpenAIResponsesClient(BaseModelClient):
break
if not text_content:
logger.warning(
f"[{self.model_name}] No output_text found in content. Full content: {content_list}"
)
logger.warning(f"[{self.model_name}] No output_text found in content. Full content: {content_list}")
return ""
return text_content.strip()
except (KeyError, IndexError, TypeError) as e:
logger.error(
f"[{self.model_name}] Error parsing response structure: {e}. Full response: {response_data}"
)
logger.error(f"[{self.model_name}] Error parsing response structure: {e}. Full response: {response_data}")
return ""
except aiohttp.ClientError as e:
logger.error(
f"[{self.model_name}] HTTP client error in generate_response: {e}"
)
logger.error(f"[{self.model_name}] HTTP client error in generate_response: {e}")
return ""
except Exception as e:
logger.error(
f"[{self.model_name}] Unexpected error in generate_response: {e}"
)
logger.error(f"[{self.model_name}] Unexpected error in generate_response: {e}")
return ""
@ -1129,10 +1054,7 @@ class OpenRouterClient(BaseModelClient):
if not self.api_key:
raise ValueError("OPENROUTER_API_KEY environment variable is required")
self.client = AsyncOpenAI(
base_url="https://openrouter.ai/api/v1",
api_key=self.api_key
)
self.client = AsyncOpenAI(base_url="https://openrouter.ai/api/v1", api_key=self.api_key)
logger.debug(f"[{self.model_name}] Initialized OpenRouter client")
@ -1150,10 +1072,7 @@ class OpenRouterClient(BaseModelClient):
# Prepare standard OpenAI-compatible request
response = await self.client.chat.completions.create(
model=self.model_name,
messages=[
{"role": "system", "content": system_prompt_content},
{"role": "user", "content": prompt_with_cta}
],
messages=[{"role": "system", "content": system_prompt_content}, {"role": "user", "content": prompt_with_cta}],
max_tokens=self.max_tokens,
temperature=temperature,
)
@ -1245,7 +1164,7 @@ class TogetherAIClient(BaseModelClient):
##############################################################################
# RequestsOpenAIClient sync requests, wrapped async (original + api_key)
##############################################################################
import requests, asyncio
class RequestsOpenAIClient(BaseModelClient):
"""
@ -1266,11 +1185,7 @@ class RequestsOpenAIClient(BaseModelClient):
if not self.api_key:
raise ValueError("OPENAI_API_KEY missing and no inline key provided")
self.base_url = (
base_url
or os.environ.get("OPENAI_BASE_URL")
or "https://api.openai.com/v1"
).rstrip("/")
self.base_url = (base_url or os.environ.get("OPENAI_BASE_URL") or "https://api.openai.com/v1").rstrip("/")
self.endpoint = f"{self.base_url}/chat/completions"
@ -1291,11 +1206,7 @@ class RequestsOpenAIClient(BaseModelClient):
temperature: float = 0.0,
inject_random_seed: bool = True,
) -> str:
system_prompt_content = (
f"{generate_random_seed()}\n\n{self.system_prompt}"
if inject_random_seed
else self.system_prompt
)
system_prompt_content = f"{generate_random_seed()}\n\n{self.system_prompt}" if inject_random_seed else self.system_prompt
payload = {
"model": self.model_name,
@ -1322,7 +1233,6 @@ class RequestsOpenAIClient(BaseModelClient):
return ""
##############################################################################
# 3) Factory to Load Model Client
##############################################################################
@ -1332,6 +1242,7 @@ class ModelSpec(NamedTuple):
base: Optional[str] # 'https://proxy.foo'
key: Optional[str] # 'sk-…' (may be None)
def _parse_model_spec(raw: str) -> ModelSpec:
"""
Splits once on '#' (API key) and once on '@' (base URL). A leading
@ -1441,8 +1352,6 @@ def load_model_client(model_id: str, prompts_dir: Optional[str] = None) -> BaseM
return OpenAIClient(model_name, prompts_dir, base_url)
##############################################################################
# 1) Add a method to filter visible messages (near top-level or in BaseModelClient)
##############################################################################
@ -1453,11 +1362,6 @@ def get_visible_messages_for_power(conversation_messages, power_name):
visible = []
for msg in conversation_messages:
# GLOBAL might be 'ALL' or 'GLOBAL' depending on your usage
if (
msg["recipient"] == "ALL"
or msg["recipient"] == "GLOBAL"
or msg["sender"] == power_name
or msg["recipient"] == power_name
):
if msg["recipient"] == "ALL" or msg["recipient"] == "GLOBAL" or msg["sender"] == power_name or msg["recipient"] == power_name:
visible.append(msg)
return visible # already in chronological order if appended that way

View file

@ -11,25 +11,28 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def _load_prompt_file(filename: str, prompts_dir: Optional[str] = None) -> str | None:
"""A local copy of the helper from agent.py to avoid circular imports."""
import os
try:
if prompts_dir:
filepath = os.path.join(prompts_dir, filename)
else:
current_dir = os.path.dirname(os.path.abspath(__file__))
default_prompts_dir = os.path.join(current_dir, 'prompts')
default_prompts_dir = os.path.join(current_dir, "prompts")
filepath = os.path.join(default_prompts_dir, filename)
with open(filepath, 'r', encoding='utf-8') as f:
with open(filepath, "r", encoding="utf-8") as f:
return f.read()
except Exception as e:
logger.error(f"Error loading prompt file {filepath}: {e}")
return None
async def run_diary_consolidation(
agent: 'DiplomacyAgent',
agent: "DiplomacyAgent",
game: "Game",
log_file_path: str,
entries_to_keep_unsummarized: int = 6,
@ -39,70 +42,42 @@ async def run_diary_consolidation(
Consolidate older diary entries while keeping recent ones.
This is the logic moved from the DiplomacyAgent class.
"""
logger.info(
f"[{agent.power_name}] CONSOLIDATION START — "
f"{len(agent.full_private_diary)} total full entries"
)
logger.info(f"[{agent.power_name}] CONSOLIDATION START — {len(agent.full_private_diary)} total full entries")
full_entries = [
e for e in agent.full_private_diary
if not e.startswith("[CONSOLIDATED HISTORY]")
]
full_entries = [e for e in agent.full_private_diary if not e.startswith("[CONSOLIDATED HISTORY]")]
if len(full_entries) <= entries_to_keep_unsummarized:
agent.private_diary = list(agent.full_private_diary)
logger.info(
f"[{agent.power_name}] ≤ {entries_to_keep_unsummarized} full entries — "
"skipping consolidation"
)
logger.info(f"[{agent.power_name}] ≤ {entries_to_keep_unsummarized} full entries — skipping consolidation")
return
boundary_entry = full_entries[-entries_to_keep_unsummarized]
match = re.search(r"\[[SFWRAB]\s*(\d{4})", boundary_entry)
if not match:
logger.error(
f"[{agent.power_name}] Could not parse year from boundary entry; "
"aborting consolidation"
)
logger.error(f"[{agent.power_name}] Could not parse year from boundary entry; aborting consolidation")
agent.private_diary = list(agent.full_private_diary)
return
cutoff_year = int(match.group(1))
logger.info(
f"[{agent.power_name}] Cut-off year for consolidation: {cutoff_year}"
)
logger.info(f"[{agent.power_name}] Cut-off year for consolidation: {cutoff_year}")
def _entry_year(entry: str) -> int | None:
m = re.search(r"\[[SFWRAB]\s*(\d{4})", entry)
return int(m.group(1)) if m else None
entries_to_summarize = [
e for e in full_entries
if (_entry_year(e) is not None and _entry_year(e) < cutoff_year)
]
entries_to_keep = [
e for e in full_entries
if (_entry_year(e) is None or _entry_year(e) >= cutoff_year)
]
entries_to_summarize = [e for e in full_entries if (_entry_year(e) is not None and _entry_year(e) < cutoff_year)]
entries_to_keep = [e for e in full_entries if (_entry_year(e) is None or _entry_year(e) >= cutoff_year)]
logger.info(
f"[{agent.power_name}] Summarising {len(entries_to_summarize)} entries; "
f"keeping {len(entries_to_keep)} recent entries verbatim"
)
logger.info(f"[{agent.power_name}] Summarising {len(entries_to_summarize)} entries; keeping {len(entries_to_keep)} recent entries verbatim")
if not entries_to_summarize:
agent.private_diary = list(agent.full_private_diary)
logger.warning(
f"[{agent.power_name}] No eligible entries to summarise; "
"context diary left unchanged"
)
logger.warning(f"[{agent.power_name}] No eligible entries to summarise; context diary left unchanged")
return
prompt_template = _load_prompt_file("diary_consolidation_prompt.txt", prompts_dir=prompts_dir)
if not prompt_template:
logger.error(
f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting"
)
logger.error(f"[{agent.power_name}] diary_consolidation_prompt.txt missing — aborting")
return
prompt = prompt_template.format(
@ -132,23 +107,14 @@ async def run_diary_consolidation(
new_summary_entry = f"[CONSOLIDATED HISTORY] {consolidated_text}"
agent.private_diary = [new_summary_entry] + entries_to_keep
success_flag = "TRUE"
logger.info(
f"[{agent.power_name}] Consolidation complete — "
f"{len(agent.private_diary)} context entries now"
)
logger.info(f"[{agent.power_name}] Consolidation complete — {len(agent.private_diary)} context entries now")
except Exception as exc:
logger.error(
f"[{agent.power_name}] Diary consolidation failed: {exc}", exc_info=True
)
logger.error(f"[{agent.power_name}] Diary consolidation failed: {exc}", exc_info=True)
finally:
log_llm_response(
log_file_path=log_file_path,
model_name=(
consolidation_client.model_name
if consolidation_client is not None
else agent.client.model_name
),
model_name=(consolidation_client.model_name if consolidation_client is not None else agent.client.model_name),
power_name=agent.power_name,
phase=game.current_short_phase,
response_type="diary_consolidation",

View file

@ -3,14 +3,13 @@ Formatter module for converting natural language LLM responses to structured JSO
Uses Gemini 2.5 Flash via OpenRouter to extract and format information from reasoning-focused responses.
"""
import json
import logging
import os
from typing import Dict, Any, Optional
from typing import Optional
from pathlib import Path
# Import logging function and model configuration
from .utils import log_llm_response, get_special_models
# Import client loading function
from .clients import load_model_client
@ -26,11 +25,7 @@ FORMAT_ORDER_DIARY = "order_diary"
async def format_with_gemini_flash(
raw_response: str,
format_type: str,
power_name: Optional[str] = None,
phase: Optional[str] = None,
log_file_path: Optional[str] = None
raw_response: str, format_type: str, power_name: Optional[str] = None, phase: Optional[str] = None, log_file_path: Optional[str] = None
) -> str:
"""
Format a natural language response into required JSON structure using Gemini Flash.
@ -52,7 +47,7 @@ async def format_with_gemini_flash(
FORMAT_NEGOTIATION_DIARY: "formatting/format_negotiation_diary.txt",
FORMAT_ORDERS: "formatting/format_orders.txt",
FORMAT_INITIAL_STATE: "formatting/format_initial_state.txt",
FORMAT_ORDER_DIARY: "formatting/format_order_diary.txt"
FORMAT_ORDER_DIARY: "formatting/format_order_diary.txt",
}
if format_type not in format_prompts:
@ -63,7 +58,7 @@ async def format_with_gemini_flash(
if not prompt_file.exists():
raise FileNotFoundError(f"Formatting prompt not found: {prompt_file}")
with open(prompt_file, 'r') as f:
with open(prompt_file, "r") as f:
format_prompt = f.read()
# Replace placeholder with actual response
@ -87,7 +82,7 @@ async def format_with_gemini_flash(
formatted_response = await formatter_client.generate_response(
prompt=format_prompt,
temperature=0, # Deterministic formatting
inject_random_seed=False # No need for random seed in formatting
inject_random_seed=False, # No need for random seed in formatting
)
if not formatted_response:
@ -116,7 +111,7 @@ async def format_with_gemini_flash(
response_type=f"format_{format_type}",
raw_input_prompt=format_prompt,
raw_response=formatted_response,
success="Success"
success="Success",
)
return formatted_response
@ -134,7 +129,7 @@ async def format_with_gemini_flash(
response_type=f"format_{format_type}",
raw_input_prompt=format_prompt,
raw_response=f"ERROR: {str(e)}",
success=f"Failure: {type(e).__name__}"
success=f"Failure: {type(e).__name__}",
)
# Return empty structure based on format type

View file

@ -3,7 +3,6 @@ import logging
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List, Optional
from pprint import pformat
logger = logging.getLogger("utils")
logger.setLevel(logging.INFO)
@ -24,12 +23,8 @@ class Phase:
plans: Dict[str, str] = field(default_factory=dict)
messages: List[Message] = field(default_factory=list)
submitted_orders_by_power: Dict[str, List[str]] = field(default_factory=dict)
orders_by_power: Dict[str, List[str]] = field(
default_factory=lambda: defaultdict(list)
)
results_by_power: Dict[str, List[List[str]]] = field(
default_factory=lambda: defaultdict(list)
)
orders_by_power: Dict[str, List[str]] = field(default_factory=lambda: defaultdict(list))
results_by_power: Dict[str, List[List[str]]] = field(default_factory=lambda: defaultdict(list))
# NEW: Store phase-end summaries provided by each power
phase_summaries: Dict[str, str] = field(default_factory=dict)
# NEW: Store experience/journal updates from each power for this phase
@ -39,9 +34,7 @@ class Phase:
self.plans[power_name] = plan
def add_message(self, sender: str, recipient: str, content: str):
    """Record a single message on this phase.

    Args:
        sender: Name of the power that sent the message.
        recipient: Name of the addressee.
        content: Message body.

    The three fields are wrapped in a ``Message`` and appended to
    ``self.messages`` exactly once per call.
    """
    self.messages.append(Message(sender=sender, recipient=recipient, content=content))
def add_orders(self, power: str, orders: List[str], results: List[List[str]]):
self.orders_by_power[power].extend(orders)
@ -110,14 +103,10 @@ class GameHistory:
phase.plans[power_name] = plan
logger.debug(f"Added plan for {power_name} in {phase_name}")
def add_message(self, phase_name: str, sender: str, recipient: str, message_content: str):
    """Attach a message to the phase called *phase_name*.

    Args:
        phase_name: Name of the phase the message belongs to.
        sender: Name of the power that sent the message.
        recipient: Name of the addressee.
        message_content: Message body.

    The phase is looked up through ``self._get_phase``. When that lookup
    yields a falsy result nothing is stored and nothing is logged.
    """
    phase = self._get_phase(phase_name)
    if phase:
        message = Message(sender=sender, recipient=recipient, content=message_content)
        phase.messages.append(message)
        logger.debug(f"Added message from {sender} to {recipient} in {phase_name}")
@ -153,10 +142,6 @@ class GameHistory:
return {}
return self.phases[-1].plans
def get_order_history_for_prompt(
self,
game: "Game",
@ -194,10 +179,7 @@ class GameHistory:
return str(tag).strip().lower()
engine_phases = {ph.name: ph for ph in getattr(game, "get_phase_history", lambda: [])()}
eng2code = {
"AUSTRIA": "AUT", "ENGLAND": "ENG", "FRANCE": "FRA",
"GERMANY": "GER", "ITALY": "ITA", "RUSSIA": "RUS", "TURKEY": "TUR"
}
eng2code = {"AUSTRIA": "AUT", "ENGLAND": "ENG", "FRANCE": "FRA", "GERMANY": "GER", "ITALY": "ITA", "RUSSIA": "RUS", "TURKEY": "TUR"}
norm = game.map.norm
out_lines = ["**ORDER HISTORY (Recent Rounds)**"]
@ -264,13 +246,6 @@ class GameHistory:
return "\n(No orders were issued in recent history)\n"
return "\n".join(out_lines)
def get_messages_this_round(self, power_name: str, current_phase_name: str) -> str:
current_phase: Optional[Phase] = None
for phase_obj in self.phases:
@ -325,11 +300,7 @@ class GameHistory:
if msg.recipient == power_name or (msg.recipient == "GLOBAL" and msg.sender != power_name):
# Skip if sender is this power (don't need to respond to own messages)
if msg.sender != power_name:
messages_to_power.append({
'sender': msg.sender,
'content': msg.content,
'phase': phase.name
})
messages_to_power.append({"sender": msg.sender, "content": msg.content, "phase": phase.name})
# Add debug logging
logger.info(f"Found {len(messages_to_power)} messages to {power_name} across {len(recent_phases)} phases")
@ -361,10 +332,10 @@ class GameHistory:
for msg in phase.messages:
# Handle both Message objects and dict objects
if isinstance(msg, Message):
if msg.sender == sender_name and msg.recipient not in ['GLOBAL', 'ALL']:
if msg.sender == sender_name and msg.recipient not in ["GLOBAL", "ALL"]:
sender_messages.append(msg)
else: # Assume dict
if msg['sender'] == sender_name and msg['recipient'] not in ['GLOBAL', 'ALL']:
if msg["sender"] == sender_name and msg["recipient"] not in ["GLOBAL", "ALL"]:
sender_messages.append(msg)
# Check for responses in this and next phases
@ -374,8 +345,8 @@ class GameHistory:
recipient = msg.recipient
msg_content = msg.content
else:
recipient = msg['recipient']
msg_content = msg['content']
recipient = msg["recipient"]
msg_content = msg["content"]
# Look for responses in current phase and next phases
found_response = False
@ -387,12 +358,14 @@ class GameHistory:
response_msgs = []
for m in check_phase.messages:
if isinstance(m, Message):
if m.sender == recipient and (m.recipient == sender_name or
(m.recipient in ['GLOBAL', 'ALL'] and sender_name in m.content)):
if m.sender == recipient and (
m.recipient == sender_name or (m.recipient in ["GLOBAL", "ALL"] and sender_name in m.content)
):
response_msgs.append(m)
else: # Assume dict
if m['sender'] == recipient and (m['recipient'] == sender_name or
(m['recipient'] in ['GLOBAL', 'ALL'] and sender_name in m.get('content', ''))):
if m["sender"] == recipient and (
m["recipient"] == sender_name or (m["recipient"] in ["GLOBAL", "ALL"] and sender_name in m.get("content", ""))
):
response_msgs.append(m)
if response_msgs:
@ -402,9 +375,6 @@ class GameHistory:
if not found_response:
if recipient not in ignored_by_power:
ignored_by_power[recipient] = []
ignored_by_power[recipient].append({
'phase': phase.name,
'content': msg_content
})
ignored_by_power[recipient].append({"phase": phase.name, "content": msg_content})
return ignored_by_power

View file

@ -3,7 +3,7 @@ import logging
import os
import json
import asyncio
from typing import Dict, List, Tuple, Optional, Any
from typing import Dict, Tuple, Optional, Any
from argparse import Namespace
from diplomacy import Game
@ -19,6 +19,7 @@ logger = logging.getLogger(__name__)
# --- Serialization / Deserialization ---
def serialize_agent(agent: DiplomacyAgent) -> dict:
"""Converts an agent object to a JSON-serializable dictionary."""
return {
@ -31,12 +32,8 @@ def serialize_agent(agent: DiplomacyAgent) -> dict:
"private_diary": agent.private_diary,
}
def deserialize_agent(
agent_data: dict,
prompts_dir: Optional[str] = None,
*,
override_model_id: Optional[str] = None
) -> DiplomacyAgent:
def deserialize_agent(agent_data: dict, prompts_dir: Optional[str] = None, *, override_model_id: Optional[str] = None) -> DiplomacyAgent:
"""
Recreates an agent object from a dictionary.
@ -70,6 +67,7 @@ def deserialize_agent(
# game_logic.py
_PHASE_ORDER = ["M", "R", "A"] # Movement → Retreats → Adjustments
def _next_phase_name(short: str) -> str:
"""
Return the Diplomacy phase string that chronologically follows *short*.
@ -89,13 +87,9 @@ def _next_phase_name(short: str) -> str:
else: # winter→ spring, next year
return f"S{year + 1}M"
def save_game_state(
game: Game,
agents: Dict[str, DiplomacyAgent],
game_history: GameHistory,
output_path: str,
run_config: Namespace,
completed_phase_name: str
game: Game, agents: Dict[str, DiplomacyAgent], game_history: GameHistory, output_path: str, run_config: Namespace, completed_phase_name: str
):
"""
Serialise the entire game to JSON, preserving per-phase custom metadata
@ -130,9 +124,7 @@ def save_game_state(
if extras:
previous_phase_extras[phase["name"]] = extras
except Exception as exc:
logger.warning(
"Could not load previous save to retain metadata: %s", exc, exc_info=True
)
logger.warning("Could not load previous save to retain metadata: %s", exc, exc_info=True)
# -------------------------------------------------------------- #
# 2. Build the fresh base structure from the diplomacy library. #
@ -143,11 +135,7 @@ def save_game_state(
# 3. Walk every phase and merge the metadata back in. #
# -------------------------------------------------------------- #
# Capture the *current* snapshot of every live agent exactly once.
current_state_agents = {
p_name: serialize_agent(p_agent)
for p_name, p_agent in agents.items()
if not game.powers[p_name].is_eliminated()
}
current_state_agents = {p_name: serialize_agent(p_agent) for p_name, p_agent in agents.items() if not game.powers[p_name].is_eliminated()}
for phase_block in saved_game.get("phases", []):
if int(phase_block["name"][1:5]) > run_config.max_year:
@ -167,19 +155,13 @@ def save_game_state(
# Plans for this phase may be empty in non-movement phases.
phase_obj = game_history._get_phase(phase_name)
phase_block["state_history_plans"] = (
phase_obj.plans if phase_obj else {}
)
phase_block["state_history_plans"] = phase_obj.plans if phase_obj else {}
# -------------------------------------------------------------- #
# 4. Attach top-level metadata and write atomically. #
# -------------------------------------------------------------- #
saved_game["phase_summaries"] = getattr(game, "phase_summaries", {})
saved_game["final_agent_states"] = {
p_name: {"relationships": a.relationships, "goals": a.goals}
for p_name, a in agents.items()
}
saved_game["final_agent_states"] = {p_name: {"relationships": a.relationships, "goals": a.goals} for p_name, a in agents.items()}
# Filter out phases > max_year
# saved_game["phases"] = [
@ -191,7 +173,6 @@ def save_game_state(
logger.info("Game state saved successfully.")
def load_game_state(
run_dir: str,
game_file_name: str,
@ -204,7 +185,7 @@ def load_game_state(
raise FileNotFoundError(f"Cannot resume. Save file not found at: {game_file_path}")
logger.info(f"Loading game state from: {game_file_path}")
with open(game_file_path, 'r') as f:
with open(game_file_path, "r") as f:
saved_game_data = json.load(f)
# If resuming, find the specified phase and truncate the data after it
@ -213,37 +194,37 @@ def load_game_state(
try:
# Find the index of the phase we want to resume from.
# The simulation will restart *at* resume_from_phase, so that phase is kept as the last entry.
resume_idx = next(i for i, phase in enumerate(saved_game_data['phases']) if phase['name'] == resume_from_phase)
resume_idx = next(i for i, phase in enumerate(saved_game_data["phases"]) if phase["name"] == resume_from_phase)
# Truncate the list to exclude everything after the resume phase
# Note: the state saved for a given phase represents the state at the beginning of that phase.
saved_game_data['phases'] = saved_game_data['phases'][:resume_idx+1]
saved_game_data["phases"] = saved_game_data["phases"][: resume_idx + 1]
# Wipe any data that must be regenerated.
for key in ("orders", "results", "messages"):
saved_game_data['phases'][-1].pop(key, None)
saved_game_data["phases"][-1].pop(key, None)
logger.info(f"Game history truncated to {len(saved_game_data['phases'])} phases. The next phase to run will be {resume_from_phase}.")
except StopIteration:
# If the phase is not found, maybe it's the first phase (S1901M)
if resume_from_phase == "S1901M":
saved_game_data['phases'] = []
saved_game_data["phases"] = []
logger.info("Resuming from S1901M. Starting with a clean history.")
else:
raise ValueError(f"Resume phase '{resume_from_phase}' not found in the save file.")
# Reconstruct the Game object
last_phase = saved_game_data['phases'][-1]
last_phase = saved_game_data["phases"][-1]
# Wipe the data that must be regenerated **but preserve the keys**
last_phase['orders'] = {} # was dict
last_phase['results'] = {} # was dict
last_phase['messages'] = []
last_phase["orders"] = {} # was dict
last_phase["results"] = {} # was dict
last_phase["messages"] = []
game = from_saved_game_format(saved_game_data)
game.phase_summaries = saved_game_data.get('phase_summaries', {})
game.phase_summaries = saved_game_data.get("phase_summaries", {})
# Reconstruct agents and game history from the *last* valid phase in the data
if not saved_game_data['phases']:
if not saved_game_data["phases"]:
# This happens if we are resuming from the very beginning (S1901M)
logger.info("No previous phases found. Initializing fresh agents and history.")
agents = {} # Will be created by the main loop
@ -251,10 +232,10 @@ def load_game_state(
else:
# We save the game state up to & including the current (uncompleted) phase.
# So we need to grab the agent state from the previous (completed) phase.
if len(saved_game_data['phases']) <= 1:
if len(saved_game_data["phases"]) <= 1:
last_phase_data = {}
else:
last_phase_data = saved_game_data['phases'][-2]
last_phase_data = saved_game_data["phases"][-2]
# -------------------- Rebuild agents -------------------- #
agents = {}
@ -269,18 +250,11 @@ def load_game_state(
elif len(provided) == 1:
power_model_map = dict(zip(powers_order, provided * len(powers_order)))
else:
raise ValueError(
f"Invalid --models argument: expected 1 or {len(powers_order)} "
f"items, got {len(provided)}."
)
raise ValueError(f"Invalid --models argument: expected 1 or {len(powers_order)} items, got {len(provided)}.")
if "state_agents" in last_phase_data:
logger.info("Rebuilding agents from saved state...")
prompts_dir_from_config = (
run_config.prompts_dir
if run_config and hasattr(run_config, "prompts_dir")
else None
)
prompts_dir_from_config = run_config.prompts_dir if run_config and hasattr(run_config, "prompts_dir") else None
for power_name, agent_data in last_phase_data["state_agents"].items():
override_id = power_model_map.get(power_name)
agents[power_name] = deserialize_agent(
@ -290,37 +264,27 @@ def load_game_state(
)
logger.info(f"Rebuilt {len(agents)} agents.")
else:
raise ValueError(
"Cannot resume: 'state_agents' key not found in the last phase "
"of the save file."
)
raise ValueError("Cannot resume: 'state_agents' key not found in the last phase of the save file.")
# Rebuild GameHistory
game_history = GameHistory()
logger.info("Rebuilding game history...")
for phase_data in saved_game_data['phases'][:-1]:
phase_name = phase_data['name']
for phase_data in saved_game_data["phases"][:-1]:
phase_name = phase_data["name"]
game_history.add_phase(phase_name)
# Add messages
for msg in phase_data.get('messages', []):
game_history.add_message(phase_name, msg['sender'], msg['recipient'], msg['message'])
for msg in phase_data.get("messages", []):
game_history.add_message(phase_name, msg["sender"], msg["recipient"], msg["message"])
# Add plans
if 'state_history_plans' in phase_data:
for p_name, plan in phase_data['state_history_plans'].items():
if "state_history_plans" in phase_data:
for p_name, plan in phase_data["state_history_plans"].items():
game_history.add_plan(phase_name, p_name, plan)
logger.info("Game history rebuilt.")
return game, agents, game_history, run_config
async def initialize_new_game(
args: Namespace,
game: Game,
game_history: GameHistory,
llm_log_file_path: str
) -> Dict[str, DiplomacyAgent]:
async def initialize_new_game(args: Namespace, game: Game, game_history: GameHistory, llm_log_file_path: str) -> Dict[str, DiplomacyAgent]:
"""Initializes agents for a new game."""
powers_order = sorted(list(ALL_POWERS))

View file

@ -1,7 +1,6 @@
# ai_diplomacy/initialization.py
import logging
import json
import os
from typing import Optional
from ..config import config
@ -18,10 +17,11 @@ from .formatter import format_with_gemini_flash, FORMAT_INITIAL_STATE
logger = logging.getLogger(__name__)
async def initialize_agent_state_ext(
agent: 'DiplomacyAgent',
game: 'Game',
game_history: 'GameHistory',
agent: "DiplomacyAgent",
game: "Game",
game_history: "GameHistory",
log_file_path: str,
prompts_dir: Optional[str] = None,
):
@ -37,21 +37,20 @@ async def initialize_agent_state_ext(
try:
# Load the prompt template
allowed_labels_str = ", ".join(ALLOWED_RELATIONSHIPS)
initial_prompt_template = _load_prompt_file(get_prompt_path('initial_state_prompt.txt'), prompts_dir=prompts_dir)
initial_prompt_template = _load_prompt_file(get_prompt_path("initial_state_prompt.txt"), prompts_dir=prompts_dir)
if not initial_prompt_template:
logger.error(f"[{power_name}] Could not load {get_prompt_path('initial_state_prompt.txt')}. Cannot initialize.")
return
# Format the prompt with variables
initial_prompt = initial_prompt_template.format(
power_name=power_name,
allowed_labels_str=allowed_labels_str
)
initial_prompt = initial_prompt_template.format(power_name=power_name, allowed_labels_str=allowed_labels_str)
board_state = game.get_state() if game else {}
possible_orders = game.get_all_possible_orders() if game else {}
logger.debug(f"[{power_name}] Preparing context for initial state. Board state type: {type(board_state)}, possible_orders type: {type(possible_orders)}, game_history type: {type(game_history)}")
logger.debug(
f"[{power_name}] Preparing context for initial state. Board state type: {type(board_state)}, possible_orders type: {type(possible_orders)}, game_history type: {type(game_history)}"
)
# Ensure agent.client and its methods can handle None for game/board_state/etc. if that's a possibility
# For initialization, game should always be present.
@ -76,7 +75,7 @@ async def initialize_agent_state_ext(
log_file_path=log_file_path,
power_name=power_name,
phase=current_phase,
response_type='initialization', # Context for run_llm_and_log internal error logging
response_type="initialization", # Context for run_llm_and_log internal error logging
)
logger.debug(f"[{power_name}] LLM response for initial state: {response[:300]}...") # Log a snippet
@ -86,11 +85,7 @@ async def initialize_agent_state_ext(
if config.USE_UNFORMATTED_PROMPTS:
# Format the natural language response into JSON
formatted_response = await format_with_gemini_flash(
response,
FORMAT_INITIAL_STATE,
power_name=power_name,
phase=current_phase,
log_file_path=log_file_path
response, FORMAT_INITIAL_STATE, power_name=power_name, phase=current_phase, log_file_path=log_file_path
)
else:
# Use the raw response directly (already formatted)
@ -108,12 +103,16 @@ async def initialize_agent_state_ext(
# Defensive check for update_data type if parsing was initially considered successful
if parsed_successfully:
if isinstance(update_data, str):
logger.error(f"[{power_name}] _extract_json_from_text returned a string, not a dict/list, despite not raising an exception. This indicates an unexpected parsing issue. String returned: {update_data[:300]}...")
logger.error(
f"[{power_name}] _extract_json_from_text returned a string, not a dict/list, despite not raising an exception. This indicates an unexpected parsing issue. String returned: {update_data[:300]}..."
)
update_data = {} # Treat as parsing failure
parsed_successfully = False
success_status = "Failure: ParsedAsStr"
elif not isinstance(update_data, dict): # Expecting a dict from JSON object
logger.error(f"[{power_name}] _extract_json_from_text returned a non-dict type ({type(update_data)}), expected dict. Data: {str(update_data)[:300]}")
logger.error(
f"[{power_name}] _extract_json_from_text returned a non-dict type ({type(update_data)}), expected dict. Data: {str(update_data)[:300]}"
)
update_data = {} # Treat as parsing failure
parsed_successfully = False
success_status = "Failure: NotADict"
@ -122,8 +121,8 @@ async def initialize_agent_state_ext(
initial_relationships_applied = False
if parsed_successfully:
initial_goals = update_data.get('initial_goals') or update_data.get('goals')
initial_relationships = update_data.get('initial_relationships') or update_data.get('relationships')
initial_goals = update_data.get("initial_goals") or update_data.get("goals")
initial_relationships = update_data.get("initial_relationships") or update_data.get("relationships")
if isinstance(initial_goals, list) and initial_goals:
agent.goals = initial_goals
@ -200,7 +199,7 @@ async def initialize_agent_state_ext(
response_type="initial_state_setup", # Specific type for CSV logging
raw_input_prompt=full_prompt,
raw_response=response,
success=success_status
success=success_status,
)
# Final log of state after initialization attempt

View file

@ -209,4 +209,3 @@ def _patched_generate_phase_summary(self: Game, phase_key, summary_callback=None
Game._generate_phase_summary = _patched_generate_phase_summary # type: ignore[assignment]
LOGGER.info("Game._generate_phase_summary patched with narrative generation.")

View file

@ -6,8 +6,7 @@ from typing import Dict, TYPE_CHECKING
from diplomacy.engine.message import Message, GLOBAL
from .agent import DiplomacyAgent
from .clients import load_model_client
from .utils import gather_possible_orders, load_prompt, normalize_recipient_name
from .utils import gather_possible_orders, normalize_recipient_name
if TYPE_CHECKING:
from .game_history import GameHistory
@ -21,9 +20,9 @@ load_dotenv()
async def conduct_negotiations(
game: 'Game',
game: "Game",
agents: Dict[str, DiplomacyAgent],
game_history: 'GameHistory',
game_history: "GameHistory",
model_error_stats: Dict[str, Dict[str, int]],
log_file_path: str,
max_rounds: int = 3,
@ -35,12 +34,8 @@ async def conduct_negotiations(
"""
logger.info("Starting negotiation phase.")
active_powers = [
p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()
]
eliminated_powers = [
p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()
]
active_powers = [p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()]
eliminated_powers = [p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()]
logger.info(f"Active powers for negotiations: {active_powers}")
if eliminated_powers:

View file

@ -1,14 +1,13 @@
from dotenv import load_dotenv
import logging
import concurrent.futures
from typing import Dict
from .clients import load_model_client
from .game_history import GameHistory
from .agent import DiplomacyAgent
logger = logging.getLogger(__name__)
async def planning_phase(
game,
agents: Dict[str, DiplomacyAgent],
@ -20,12 +19,8 @@ async def planning_phase(
Lets each power generate a strategic plan using their DiplomacyAgent.
"""
logger.info(f"Starting planning phase for {game.current_short_phase}...")
active_powers = [
p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()
]
eliminated_powers = [
p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()
]
active_powers = [p_name for p_name, p_obj in game.powers.items() if not p_obj.is_eliminated()]
eliminated_powers = [p_name for p_name, p_obj in game.powers.items() if p_obj.is_eliminated()]
logger.info(f"Active powers for planning: {active_powers}")
if eliminated_powers:
@ -35,9 +30,7 @@ async def planning_phase(
board_state = game.get_state()
with concurrent.futures.ThreadPoolExecutor(
max_workers=len(active_powers)
) as executor:
with concurrent.futures.ThreadPoolExecutor(max_workers=len(active_powers)) as executor:
futures = {}
for power_name in active_powers:
if power_name not in agents:
@ -71,16 +64,14 @@ async def planning_phase(
if plan_result.startswith("Error:"):
logger.warning(f"Agent {power_name} reported an error during planning: {plan_result}")
if power_name in model_error_stats:
model_error_stats[power_name].setdefault('planning_generation_errors', 0)
model_error_stats[power_name]['planning_generation_errors'] += 1
model_error_stats[power_name].setdefault("planning_generation_errors", 0)
model_error_stats[power_name]["planning_generation_errors"] += 1
else:
model_error_stats.setdefault(f'{power_name}_planning_generation_errors', 0)
model_error_stats[f'{power_name}_planning_generation_errors'] += 1
model_error_stats.setdefault(f"{power_name}_planning_generation_errors", 0)
model_error_stats[f"{power_name}_planning_generation_errors"] += 1
elif plan_result:
agent.add_journal_entry(f"Generated plan for {game.current_short_phase}: {plan_result[:100]}...")
game_history.add_plan(
game.current_short_phase, power_name, plan_result
)
game_history.add_plan(game.current_short_phase, power_name, plan_result)
logger.debug(f"Added plan for {power_name} to history.")
else:
logger.warning(f"Agent {power_name} returned an empty plan.")
@ -88,11 +79,11 @@ async def planning_phase(
except Exception as e:
logger.error(f"Exception during planning result processing for {power_name}: {e}")
if power_name in model_error_stats:
model_error_stats[power_name].setdefault('planning_execution_errors', 0)
model_error_stats[power_name]['planning_execution_errors'] += 1
model_error_stats[power_name].setdefault("planning_execution_errors", 0)
model_error_stats[power_name]["planning_execution_errors"] += 1
else:
model_error_stats.setdefault(f'{power_name}_planning_execution_errors', 0)
model_error_stats[f'{power_name}_planning_execution_errors'] += 1
model_error_stats.setdefault(f"{power_name}_planning_execution_errors", 0)
model_error_stats[f"{power_name}_planning_execution_errors"] += 1
logger.info("Planning phase processing complete.")
return game_history

View file

@ -14,6 +14,7 @@ import re
logger = logging.getLogger(__name__)
def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]:
"""
Return graph[PROV]['ARMY'|'FLEET'] = list of 3-letter neighbour provinces.
@ -27,9 +28,7 @@ def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]:
if len(loc.split("/")[0]) == 3
}
graph: Dict[str, Dict[str, List[str]]] = {
p: {"ARMY": [], "FLEET": []} for p in provs
}
graph: Dict[str, Dict[str, List[str]]] = {p: {"ARMY": [], "FLEET": []} for p in provs}
# ── helper: list every concrete variant of a province ─────────────────
def variants(code: str) -> List[str]:
@ -55,11 +54,7 @@ def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]:
graph[src]["ARMY"].append(dest)
# FLEET - any src variant that can host a fleet is valid
if any(
game_map.abuts("F", sv, "-", dv)
for sv in src_vers
for dv in dest_vers
):
if any(game_map.abuts("F", sv, "-", dv) for sv in src_vers for dv in dest_vers):
graph[src]["FLEET"].append(dest)
# ── tidy up duplicates / order ---------------------------------------
@ -70,24 +65,23 @@ def build_diplomacy_graph(game_map: GameMap) -> Dict[str, Dict[str, List[str]]]:
return graph
def bfs_shortest_path(
graph: Dict[str, Dict[str, List[str]]],
board_state: BoardState,
game_map: GameMap, # Added game_map
start_loc_full: str, # This is a FULL location name like 'VIE' or 'STP/SC'
unit_type: str,
is_target_func: Callable[[str, BoardState], bool] # Expects SHORT name for loc
is_target_func: Callable[[str, BoardState], bool], # Expects SHORT name for loc
) -> Optional[List[str]]: # Returns path of SHORT names
"""Performs BFS to find the shortest path from start_loc to a target satisfying is_target_func."""
# Convert full start location to short province name
start_loc_short = game_map.loc_name.get(start_loc_full, start_loc_full)
if '/' in start_loc_short: # If it was STP/SC, loc_name gives STP. If it was VIE, loc_name gives VIE.
if "/" in start_loc_short: # If it was STP/SC, loc_name gives STP. If it was VIE, loc_name gives VIE.
start_loc_short = start_loc_short[:3]
# If start_loc_full was already short (e.g. 'VIE'), get might return it as is, or its value if it was a key.
# A simpler way for non-coastal full (like 'VIE') or already short:
if '/' not in start_loc_full:
if "/" not in start_loc_full:
start_loc_short = start_loc_full[:3] # Ensures 'VIE' -> 'VIE', 'PAR' -> 'PAR'
else: # Has '/', e.g. 'STP/SC'
start_loc_short = start_loc_full[:3] # 'STP/SC' -> 'STP'
@ -119,10 +113,11 @@ def bfs_shortest_path(
queue.append((next_loc_short, new_path))
return None
# --- Helper functions for context generation ---
def get_unit_at_location(board_state: BoardState, location: str) -> Optional[str]:
"""Returns the full unit string (e.g., 'A PAR (FRA)') if a unit is at the location, else None."""
for power, unit_list in board_state.get('units', {}).items():
for power, unit_list in board_state.get("units", {}).items():
for unit_str in unit_list: # e.g., "A PAR", "F STP/SC"
parts = unit_str.split(" ")
if len(parts) == 2:
@ -131,24 +126,26 @@ def get_unit_at_location(board_state: BoardState, location: str) -> Optional[str
return f"{parts[0]} {location} ({power})"
return None
def get_sc_controller(game_map: GameMap, board_state: BoardState, location: str) -> Optional[str]:
    """Return the name of the power controlling the supply centre at *location*.

    Args:
        game_map: Map object; its ``loc_name`` mapping and ``scs`` collection are read.
        board_state: Board state; its ``"centers"`` entry (power -> list of
            supply-centre codes) is read. A missing entry is treated as empty.
        location: Location name, with or without a coast suffix.

    Returns:
        The controlling power's name, or ``None`` when the location is not a
        supply centre or no power lists it among its centres.
    """
    # Normalize location to base province name, as SCs are tied to provinces, not specific coasts
    loc_province_name = game_map.loc_name.get(location, location).upper()[:3]
    if loc_province_name not in game_map.scs:
        return None
    for power, sc_list in board_state.get("centers", {}).items():
        if loc_province_name in sc_list:
            return power
    return None  # Unowned SC
def get_shortest_path_to_friendly_unit(
board_state: BoardState,
graph: Dict[str, Dict[str, List[str]]],
game_map: GameMap, # Added game_map
power_name: str,
start_unit_loc_full: str,
start_unit_type: str
start_unit_type: str,
) -> Optional[Tuple[str, List[str]]]:
"""Finds the shortest path to any friendly unit of the same power."""
@ -183,7 +180,7 @@ def get_nearest_enemy_units(
power_name: str,
start_unit_loc_full: str,
start_unit_type: str,
n: int = 3
n: int = 3,
) -> List[Tuple[str, List[str]]]:
"""Finds up to N nearest enemy units, sorted by path length."""
enemy_paths: List[Tuple[str, List[str]]] = [] # (enemy_unit_str, path_short_names)
@ -205,9 +202,9 @@ def get_nearest_enemy_units(
for target_enemy_loc_full, enemy_unit_str in all_enemy_unit_locations_full:
target_enemy_loc_short = game_map.loc_name.get(target_enemy_loc_full, target_enemy_loc_full)
if '/' in target_enemy_loc_short:
if "/" in target_enemy_loc_short:
target_enemy_loc_short = target_enemy_loc_short[:3]
if '/' not in target_enemy_loc_full:
if "/" not in target_enemy_loc_full:
target_enemy_loc_short = target_enemy_loc_full[:3]
else:
target_enemy_loc_short = target_enemy_loc_full[:3]
@ -281,7 +278,7 @@ def get_adjacent_territory_details(
board_state: BoardState,
unit_loc_full: str, # The location of the unit whose adjacencies we're checking
unit_type: str, # ARMY or FLEET of the unit at unit_loc_full
graph: Dict[str, Dict[str, List[str]]]
graph: Dict[str, Dict[str, List[str]]],
) -> str:
"""Generates a string describing adjacent territories and units that can interact with them."""
output_lines: List[str] = []
@ -291,9 +288,9 @@ def get_adjacent_territory_details(
# For fleets, graph[unit_loc_full]['FLEET'] gives full loc names (incl coasts)
# THIS COMMENT IS NOW OUTDATED. Graph uses short names for keys and values.
unit_loc_short = game_map.loc_name.get(unit_loc_full, unit_loc_full)
if '/' in unit_loc_short:
if "/" in unit_loc_short:
unit_loc_short = unit_loc_short[:3]
if '/' not in unit_loc_full:
if "/" not in unit_loc_full:
unit_loc_short = unit_loc_full[:3]
else:
unit_loc_short = unit_loc_full[:3]
@ -308,11 +305,11 @@ def get_adjacent_territory_details(
continue
processed_adj_provinces.add(adj_loc_short)
adj_loc_type = game_map.loc_type.get(adj_loc_short, 'UNKNOWN').upper()
if adj_loc_type == 'COAST' or adj_loc_type == 'LAND':
adj_loc_type_display = 'LAND' if adj_loc_type == 'LAND' else 'COAST'
elif adj_loc_type == 'WATER':
adj_loc_type_display = 'WATER'
adj_loc_type = game_map.loc_type.get(adj_loc_short, "UNKNOWN").upper()
if adj_loc_type == "COAST" or adj_loc_type == "LAND":
adj_loc_type_display = "LAND" if adj_loc_type == "LAND" else "COAST"
elif adj_loc_type == "WATER":
adj_loc_type_display = "WATER"
else: # SHUT etc.
adj_loc_type_display = adj_loc_type
@ -330,8 +327,7 @@ def get_adjacent_territory_details(
# "Can support/move to" - Simplified: list units in *further* adjacent provinces
# A true "can support/move to" would require checking possible orders of those further units.
# further_adj_provinces are short names from the graph
further_adj_provinces_short = graph.get(adj_loc_short, {}).get('ARMY', []) + \
graph.get(adj_loc_short, {}).get('FLEET', [])
further_adj_provinces_short = graph.get(adj_loc_short, {}).get("ARMY", []) + graph.get(adj_loc_short, {}).get("FLEET", [])
supporting_units_info = []
processed_further_provinces = set()
@ -388,12 +384,12 @@ def generate_rich_order_context_xml(game: Any, power_name: str, possible_orders_
continue
unit_type_char = unit_str_full.split(" ")[0] # 'A' or 'F'
unit_type_long = "ARMY" if unit_type_char == 'A' else "FLEET"
unit_type_long = "ARMY" if unit_type_char == "A" else "FLEET"
loc_province_short = game_map.loc_name.get(unit_loc_full, unit_loc_full).upper()[:3]
loc_type_short = game_map.loc_type.get(loc_province_short, "UNKNOWN").upper()
if loc_type_short == 'COAST' or loc_type_short == 'LAND':
loc_type_display = 'LAND' if loc_type_short == 'LAND' else 'COAST'
if loc_type_short == "COAST" or loc_type_short == "LAND":
loc_type_display = "LAND" if loc_type_short == "LAND" else "COAST"
else:
loc_type_display = loc_type_short
@ -401,7 +397,7 @@ def generate_rich_order_context_xml(game: Any, power_name: str, possible_orders_
current_unit_lines.append(f' <UnitContext loc="{unit_loc_full}">')
# Unit Information section
current_unit_lines.append(' <UnitInformation>')
current_unit_lines.append(" <UnitInformation>")
sc_owner_at_loc = get_sc_controller(game_map, board_state, unit_loc_full)
header_content = f"Strategic territory held by {power_name}: {unit_loc_full} ({loc_type_display})"
if sc_owner_at_loc == power_name:
@ -410,55 +406,61 @@ def generate_rich_order_context_xml(game: Any, power_name: str, possible_orders_
header_content += f" (SC controlled by {sc_owner_at_loc})"
current_unit_lines.append(f" {header_content}")
current_unit_lines.append(f" Units present: {unit_str_full}")
current_unit_lines.append(' </UnitInformation>')
current_unit_lines.append(" </UnitInformation>")
# Possible moves section
current_unit_lines.append(' <PossibleMoves>')
current_unit_lines.append(" <PossibleMoves>")
current_unit_lines.append(" Possible moves:")
for order_str in unit_specific_possible_orders:
current_unit_lines.append(f" {order_str}")
current_unit_lines.append(' </PossibleMoves>')
current_unit_lines.append(" </PossibleMoves>")
# Nearest enemy units section
enemy_units_info = get_nearest_enemy_units(board_state, graph, game_map, power_name, unit_loc_full, unit_type_long, n=3)
current_unit_lines.append(' <NearestEnemyUnits>')
current_unit_lines.append(" <NearestEnemyUnits>")
if enemy_units_info:
current_unit_lines.append(" Nearest units (not ours):")
for enemy_unit_str, enemy_path_short in enemy_units_info:
current_unit_lines.append(f" {enemy_unit_str}, path=[{unit_loc_full}{(''.join(enemy_path_short[1:])) if len(enemy_path_short) > 1 else enemy_path_short[0]}]")
current_unit_lines.append(
f" {enemy_unit_str}, path=[{unit_loc_full}{(''.join(enemy_path_short[1:])) if len(enemy_path_short) > 1 else enemy_path_short[0]}]"
)
else:
current_unit_lines.append(" Nearest units (not ours): None found")
current_unit_lines.append(' </NearestEnemyUnits>')
current_unit_lines.append(" </NearestEnemyUnits>")
# Nearest supply centers (not controlled by us) section
uncontrolled_scs_info = get_nearest_uncontrolled_scs(game_map, board_state, graph, power_name, unit_loc_full, unit_type_long, n=3)
current_unit_lines.append(' <NearestUncontrolledSupplyCenters>')
current_unit_lines.append(" <NearestUncontrolledSupplyCenters>")
if uncontrolled_scs_info:
current_unit_lines.append(" Nearest supply centers (not controlled by us):")
for sc_str, dist, sc_path_short in uncontrolled_scs_info:
current_unit_lines.append(f" {sc_str}, dist={dist}, path=[{unit_loc_full}{(''.join(sc_path_short[1:])) if len(sc_path_short) > 1 else sc_path_short[0]}]")
current_unit_lines.append(
f" {sc_str}, dist={dist}, path=[{unit_loc_full}{(''.join(sc_path_short[1:])) if len(sc_path_short) > 1 else sc_path_short[0]}]"
)
else:
current_unit_lines.append(" Nearest supply centers (not controlled by us): None found")
current_unit_lines.append(' </NearestUncontrolledSupplyCenters>')
current_unit_lines.append(" </NearestUncontrolledSupplyCenters>")
# Adjacent territories details section
adj_details_str = get_adjacent_territory_details(game_map, board_state, unit_loc_full, unit_type_long, graph)
current_unit_lines.append(' <AdjacentTerritories>')
current_unit_lines.append(" <AdjacentTerritories>")
if adj_details_str:
current_unit_lines.append(" Adjacent territories (including units that can support/move to the adjacent territory):")
# Assuming adj_details_str is already formatted with newlines and indentation for its content
# We might need to indent adj_details_str if it's a single block of text
# For now, let's add a standard indent to each line of adj_details_str if it contains newlines
if '\n' in adj_details_str:
indented_adj_details = "\n".join([f" {line}" for line in adj_details_str.split('\n')])
if "\n" in adj_details_str:
indented_adj_details = "\n".join([f" {line}" for line in adj_details_str.split("\n")])
current_unit_lines.append(indented_adj_details)
else:
current_unit_lines.append(f" {adj_details_str}")
else:
current_unit_lines.append(" Adjacent territories: None relevant or all are empty/uncontested by direct threats.") # Added more descriptive else
current_unit_lines.append(' </AdjacentTerritories>')
current_unit_lines.append(
" Adjacent territories: None relevant or all are empty/uncontested by direct threats."
) # Added more descriptive else
current_unit_lines.append(" </AdjacentTerritories>")
current_unit_lines.append(' </UnitContext>')
current_unit_lines.append(" </UnitContext>")
final_context_lines.extend(current_unit_lines)
final_context_lines.append("</PossibleOrdersContext>")
@ -469,7 +471,6 @@ def generate_rich_order_context_xml(game: Any, power_name: str, possible_orders_
# Regex and tiny helpers
# ---------------------------------------------------------------------------
import re
from typing import Tuple, List, Dict, Optional, Any
# ── order-syntax matchers ─────────────────────────────────────────────────
@ -478,6 +479,7 @@ _HOLD_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? H$") # NEW
_RETREAT_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? R [A-Z]{3}(?:/[A-Z]{2})?$")
_ADJUST_RE = re.compile(r"^[AF] [A-Z]{3}(?:/[A-Z]{2})? [BD]$") # build / disband
def _is_hold_order(order: str) -> bool:  # NEW
    """Tell whether *order*, ignoring surrounding whitespace, matches ``_HOLD_RE``."""
    candidate = order.strip()
    # A successful match object is always truthy, so this equals bool(match).
    return _HOLD_RE.match(candidate) is not None
@ -486,15 +488,19 @@ def _norm_power(name: str) -> str:
"""Trim & uppercase for reliable comparisons."""
return name.strip().upper()
def _is_simple_move(order: str) -> bool:
    """Tell whether *order*, ignoring surrounding whitespace, matches ``_SIMPLE_MOVE_RE``."""
    candidate = order.strip()
    # A successful match object is always truthy, so this equals bool(match).
    return _SIMPLE_MOVE_RE.match(candidate) is not None
def _is_retreat_order(order: str) -> bool:
    """Tell whether *order*, ignoring surrounding whitespace, matches ``_RETREAT_RE``."""
    candidate = order.strip()
    # A successful match object is always truthy, so this equals bool(match).
    return _RETREAT_RE.match(candidate) is not None
def _is_adjust_order(order: str) -> bool:
    """Tell whether *order*, ignoring surrounding whitespace, matches ``_ADJUST_RE``."""
    candidate = order.strip()
    # A successful match object is always truthy, so this equals bool(match).
    return _ADJUST_RE.match(candidate) is not None
def _split_move(order: str) -> Tuple[str, str]:
"""Return ('A BUD', 'TRI') from 'A BUD - TRI' (validated move only)."""
unit_part, dest = order.split(" - ")
@ -505,6 +511,7 @@ def _split_move(order: str) -> Tuple[str, str]:
# Gather *all* friendly support orders for a given move
# ---------------------------------------------------------------------------
def _all_support_examples(
mover: str,
dest: str,
@ -526,6 +533,7 @@ def _all_support_examples(
return supports
def _all_support_hold_examples(
holder: str,
all_orders: Dict[str, List[str]],
@ -550,6 +558,7 @@ def _all_support_hold_examples(
# Province-type resolver (handles short codes, coasts, seas)
# ---------------------------------------------------------------------------
def _province_type_display(game_map, prov_short: str) -> str:
"""
Return 'LAND', 'COAST', or 'WATER' for the 3-letter province code.
@ -593,6 +602,7 @@ def _dest_occupancy_desc(
# Adjacent-territory lines (used by movement-phase builder)
# ---------------------------------------------------------------------------
def _adjacent_territory_lines(
graph,
game_map,
@ -633,13 +643,9 @@ def _adjacent_territory_lines(
pwr = unit_here.split(" ")[-1].strip("()")
if pwr == our_power:
friend_descr = unit_here.split(" (")[0]
lines.append(
f"{indent2}Support hold: {mover_descr} S {friend_descr}"
)
lines.append(f"{indent2}Support hold: {mover_descr} S {friend_descr}")
else:
lines.append(
f"{indent2}-> {unit_here} can support or contest {mover_descr}s moves and vice-versa"
)
lines.append(f"{indent2}-> {unit_here} can support or contest {mover_descr}s moves and vice-versa")
return lines
@ -648,6 +654,7 @@ def _adjacent_territory_lines(
# Movement-phase generator (UNCHANGED LOGIC)
# ---------------------------------------------------------------------------
def _generate_rich_order_context_movement(
game: Any,
power_name: str,
@ -672,18 +679,14 @@ def _generate_rich_order_context_movement(
if _norm_power(unit_power) != me:
continue # Skip units that arent ours
mover_descr, _ = _split_move(
f"{unit_full_str.split(' ')[0]} {unit_loc_full} - {unit_loc_full}"
)
mover_descr, _ = _split_move(f"{unit_full_str.split(' ')[0]} {unit_loc_full} - {unit_loc_full}")
prov_short = game_map.loc_name.get(unit_loc_full, unit_loc_full)[:3]
prov_type_disp = _province_type_display(game_map, prov_short)
sc_tag = " (SC)" if prov_short in game_map.scs else ""
owner = get_sc_controller(game_map, board_state, unit_loc_full) or "None"
owner_line = (
f"Held by {owner} (You)" if owner == power_name else f"Held by {owner}"
)
owner_line = f"Held by {owner} (You)" if owner == power_name else f"Held by {owner}"
ind = " "
block: List[str] = [f"<Territory {prov_short}>"]
@ -693,18 +696,16 @@ def _generate_rich_order_context_movement(
# ----- Adjacent territories -----
block.append("# Adjacent territories:")
block.extend(
_adjacent_territory_lines(
graph, game_map, board_state,
unit_loc_full, mover_descr, power_name
)
)
block.extend(_adjacent_territory_lines(graph, game_map, board_state, unit_loc_full, mover_descr, power_name))
# ----- Nearest enemy units -----
block.append("# Nearest units (not ours):")
enemies = get_nearest_enemy_units(
board_state, graph, game_map,
power_name, unit_loc_full,
board_state,
graph,
game_map,
power_name,
unit_loc_full,
"ARMY" if mover_descr.startswith("A") else "FLEET",
n=3,
)
@ -715,8 +716,11 @@ def _generate_rich_order_context_movement(
# ----- Nearest uncontrolled SCs -----
block.append("# Nearest supply centers (not controlled by us):")
scs = get_nearest_uncontrolled_scs(
game_map, board_state, graph,
power_name, unit_loc_full,
game_map,
board_state,
graph,
power_name,
unit_loc_full,
"ARMY" if mover_descr.startswith("A") else "FLEET",
n=3,
)
@ -737,9 +741,7 @@ def _generate_rich_order_context_movement(
# ---- Moves (same behaviour as before) ----
for mv in simple_moves:
mover, dest = _split_move(mv)
occ = _dest_occupancy_desc(
dest.split("/")[0][:3], game_map, board_state, power_name
)
occ = _dest_occupancy_desc(dest.split("/")[0][:3], game_map, board_state, power_name)
block.append(f"{ind}{mv} {occ}")
for s in _all_support_examples(mover, dest, possible_orders_for_power):
@ -753,7 +755,6 @@ def _generate_rich_order_context_movement(
for s in _all_support_hold_examples(holder, possible_orders_for_power):
block.append(f"{ind * 2}Available Support: {s}")
block.append(f"</Territory {prov_short}>")
blocks.append("\n".join(block))
@ -764,6 +765,7 @@ def _generate_rich_order_context_movement(
# Retreat-phase builder – echo orders verbatim, no tags
# ---------------------------------------------------------------------------
def _generate_rich_order_context_retreat(
game: Any,
power_name: str,
@ -784,11 +786,11 @@ def _generate_rich_order_context_retreat(
return "\n".join(lines) if lines else "(No dislodged units)"
# ---------------------------------------------------------------------------
# Adjustment-phase builder – summary line + orders, no WAIVEs, no tags
# ---------------------------------------------------------------------------
def _generate_rich_order_context_adjustment(
game: Any,
power_name: str,
@ -824,13 +826,11 @@ def _generate_rich_order_context_adjustment(
return "\n".join(lines) if len(lines) > 1 else summary
# ---------------------------------------------------------------------------
# Phase-dispatch wrapper (public entry point)
# ---------------------------------------------------------------------------
def generate_rich_order_context(
game: Any,
power_name: str,
@ -847,21 +847,13 @@ def generate_rich_order_context(
phase_type = game.current_short_phase[-1]
if phase_type == "M": # Movement
return _generate_rich_order_context_movement(
game, power_name, possible_orders_for_power
)
return _generate_rich_order_context_movement(game, power_name, possible_orders_for_power)
if phase_type == "R": # Retreat
return _generate_rich_order_context_retreat(
game, power_name, possible_orders_for_power
)
return _generate_rich_order_context_retreat(game, power_name, possible_orders_for_power)
if phase_type == "A": # Adjustment (build / disband)
return _generate_rich_order_context_adjustment(
game, power_name, possible_orders_for_power
)
return _generate_rich_order_context_adjustment(game, power_name, possible_orders_for_power)
# Fallback treat unknown formats as movement
return _generate_rich_order_context_movement(
game, power_name, possible_orders_for_power
)
return _generate_rich_order_context_movement(game, power_name, possible_orders_for_power)

View file

@ -11,7 +11,6 @@ from .possible_order_context import (
generate_rich_order_context,
generate_rich_order_context_xml,
)
import os
from .game_history import GameHistory # Assuming GameHistory is correctly importable
# placeholder for diplomacy.Game to avoid circular or direct dependency if not needed for typehinting only
@ -230,4 +229,3 @@ def construct_order_generation_prompt(
logger.debug(f"Final order generation prompt preview for {power_name}: {final_prompt[:500]}...")
return final_prompt

View file

@ -35,7 +35,7 @@ def atomic_write_json(data: dict, filepath: str):
# Write to a temporary file in the same directory
temp_filepath = f"{filepath}.tmp.{os.getpid()}"
with open(temp_filepath, 'w', encoding='utf-8') as f:
with open(temp_filepath, "w", encoding="utf-8") as f:
json.dump(data, f, indent=4)
# Atomically rename the temporary file to the final destination
@ -110,11 +110,7 @@ def get_special_models() -> Dict[str, str]:
# Use a different Gemini model for formatting
export AI_DIPLOMACY_FORMATTER_MODEL="gemini-2.0-flash"
"""
return {
"phase_summary": config.AI_DIPLOMACY_NARRATIVE_MODEL,
"formatter": config.AI_DIPLOMACY_FORMATTER_MODEL
}
return {"phase_summary": config.AI_DIPLOMACY_NARRATIVE_MODEL, "formatter": config.AI_DIPLOMACY_FORMATTER_MODEL}
def gather_possible_orders(game: Game, power_name: str) -> Dict[str, List[str]]:
@ -175,15 +171,19 @@ async def get_valid_orders(
# ── 2. Type check ──────────────────────────────────────────
if not isinstance(raw_orders, list):
logger.warning("[%s] Orders received from LLM are not a list: %s. Using fallback.",
power_name, raw_orders)
logger.warning("[%s] Orders received from LLM are not a list: %s. Using fallback.", power_name, raw_orders)
model_error_stats[client.model_name]["order_decoding_errors"] += 1
return {"valid": client.fallback_orders(possible_orders), "invalid": []}
# ── 3. Round-trip validation with engine ───────────────────
CODE_TO_ENGINE = {
"AUT": "AUSTRIA", "ENG": "ENGLAND", "FRA": "FRANCE",
"GER": "GERMANY", "ITA": "ITALY", "RUS": "RUSSIA", "TUR": "TURKEY",
"AUT": "AUSTRIA",
"ENG": "ENGLAND",
"FRA": "FRANCE",
"GER": "GERMANY",
"ITA": "ITALY",
"RUS": "RUSSIA",
"TUR": "TURKEY",
}
engine_power = power_name if power_name in game.powers else CODE_TO_ENGINE[power_name]
@ -226,8 +226,6 @@ async def get_valid_orders(
return {"valid": valid, "invalid": invalid}
def normalize_and_compare_orders(
issued_orders: Dict[str, List[str]],
accepted_orders_dict: Dict[str, List[str]],
@ -348,8 +346,6 @@ def load_prompt(filename: str, prompts_dir: Optional[str] = None) -> str:
return ""
# == New LLM Response Logging Function ==
def log_llm_response(
log_file_path: str,
@ -374,14 +370,18 @@ def log_llm_response(
with open(log_file_path, "a", newline="", encoding="utf-8") as csvfile:
# Added "raw_input" to fieldnames
fieldnames = ["model", "power", "phase", "response_type", "raw_input", "raw_response", "success"]
writer = csv.DictWriter(csvfile, fieldnames=fieldnames,
writer = csv.DictWriter(
csvfile,
fieldnames=fieldnames,
quoting=csv.QUOTE_ALL, # Quote all fields to handle commas and newlines
escapechar='\\') # Use backslash for escaping
escapechar="\\",
) # Use backslash for escaping
if not file_exists:
writer.writeheader() # Write header only if file is new
writer.writerow({
writer.writerow(
{
"model": model_name,
"power": power_name if power_name else "game", # Use 'game' if no specific power
"phase": phase,
@ -389,12 +389,12 @@ def log_llm_response(
"raw_input": raw_input_prompt, # Added raw_input to the row
"raw_response": raw_response,
"success": success,
})
}
)
except Exception as e:
logger.error(f"Failed to log LLM response to {log_file_path}: {e}", exc_info=True)
async def run_llm_and_log(
client: "BaseModelClient",
prompt: str,
@ -428,8 +428,7 @@ async def run_llm_and_log(
except Exception as e:
if attempt == attempts - 1:
logger.error(
f"API Error during LLM call for {client.model_name}/{power_name}/{response_type} "
f"in phase {phase}: {e}",
f"API Error during LLM call for {client.model_name}/{power_name}/{response_type} in phase {phase}: {e}",
exc_info=True,
)
# Back-off before the next attempt (unless this was the last)
@ -439,7 +438,6 @@ async def run_llm_and_log(
return raw_response
# This generates a few lines of random alphanum chars to inject into the
# system prompt. This lets us use temp=0 while still getting variation
# between trials.
@ -447,15 +445,8 @@ async def run_llm_and_log(
# ensure valid json outputs.
def generate_random_seed(n_lines: int = 5, n_chars_per_line: int = 80) -> str:
    """Build a block of random alphanumeric text wrapped in RANDOM SEED markers.

    The block is meant to be injected into a prompt so that runs at
    temperature 0 still vary between trials.

    Args:
        n_lines: Number of random lines to generate.
        n_chars_per_line: Number of characters on each random line.

    Returns:
        A string that starts with ``<RANDOM SEED PLEASE IGNORE>``, contains
        ``n_lines`` newline-separated lines of ASCII letters and digits, and
        ends with ``</RANDOM SEED>``.
    """
    # Hoisted so the character pool is built once rather than once per line.
    alphabet = string.ascii_letters + string.digits
    seed_lines = ["".join(random.choices(alphabet, k=n_chars_per_line)) for _ in range(n_lines)]
    random_seed_block = "<RANDOM SEED PLEASE IGNORE>\n" + "\n".join(seed_lines) + "\n</RANDOM SEED>"
    return random_seed_block
@ -473,6 +464,7 @@ def get_prompt_path(prompt_name: str) -> str:
else:
return prompt_name
def normalize_recipient_name(recipient: str) -> str:
"""Normalize recipient names to handle LLM typos and abbreviations."""
if not recipient:
@ -482,24 +474,24 @@ def normalize_recipient_name(recipient: str) -> str:
# Handle common LLM typos and abbreviations found in data
name_mapping = {
'EGMANY': 'GERMANY',
'GERMAN': 'GERMANY',
'UK': 'ENGLAND',
'BRIT': 'ENGLAND',
'ENGLAND': 'ENGLAND', # Keep as-is
'FRANCE': 'FRANCE', # Keep as-is
'GERMANY': 'GERMANY', # Keep as-is
'ITALY': 'ITALY', # Keep as-is
'AUSTRIA': 'AUSTRIA', # Keep as-is
'RUSSIA': 'RUSSIA', # Keep as-is
'TURKEY': 'TURKEY', # Keep as-is
'Germany': 'GERMANY',
'England': 'ENGLAND',
'France': 'FRANCE',
'Italy': 'ITALY',
'Russia': 'RUSSIA',
'Austria': 'AUSTRIA',
'Turkey': 'TURKEY',
"EGMANY": "GERMANY",
"GERMAN": "GERMANY",
"UK": "ENGLAND",
"BRIT": "ENGLAND",
"ENGLAND": "ENGLAND", # Keep as-is
"FRANCE": "FRANCE", # Keep as-is
"GERMANY": "GERMANY", # Keep as-is
"ITALY": "ITALY", # Keep as-is
"AUSTRIA": "AUSTRIA", # Keep as-is
"RUSSIA": "RUSSIA", # Keep as-is
"TURKEY": "TURKEY", # Keep as-is
"Germany": "GERMANY",
"England": "ENGLAND",
"France": "FRANCE",
"Italy": "ITALY",
"Russia": "RUSSIA",
"Austria": "AUSTRIA",
"Turkey": "TURKEY",
}
normalized = name_mapping.get(recipient, recipient)